diff --git a/config/config2.example.yaml b/config/config2.example.yaml index ba480d984..a24892c2a 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -79,6 +79,8 @@ exp_pool: enable_read: false enable_write: false persist_path: .chroma_exp_data # The directory. + retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding. + use_llm_ranker: false # If `use_llm_ranker` is true, then it will use LLM Reranker to get better result, but it is not always guaranteed that the output will be parseable for reranking. azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/examples/exp_pool/README.md b/examples/exp_pool/README.md index d0b49f2ad..37e7853f8 100644 --- a/examples/exp_pool/README.md +++ b/examples/exp_pool/README.md @@ -3,7 +3,7 @@ # Experience Pool ## Prerequisites - Ensure the RAG module is installed: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html - Set embedding: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html -- Set both `enable_read` and `enable_write` to `true` in the `exp_pool` section of `config2.yaml` +- Set `enabled`、`enable_read` and `enable_write` to `true` in the `exp_pool` section of `config2.yaml` ## Example Files diff --git a/metagpt/actions/analyze_requirements.py b/metagpt/actions/analyze_requirements.py index ac412b1d1..d81da3e14 100644 --- a/metagpt/actions/analyze_requirements.py +++ b/metagpt/actions/analyze_requirements.py @@ -24,7 +24,7 @@ Requirements: 创建一个贪吃蛇,只需要给出设计文档和代码 Outputs: [User Restrictions] : 只需要给出设计文档和代码. -[Language Restrictions] : The response, message and instruction must be in the language specified by Chinese. +[Language Restrictions] : The response, message and instruction must be in Chinese. [Programming Language] : HTML (*.html), CSS (*.css), and JavaScript (*.js) Example 2 @@ -32,7 +32,7 @@ Requirements: Create 2048 game using Python. Do not write PRD. Outputs: [User Restrictions] : Do not write PRD. -[Language Restrictions] : The response, message and instruction must be in the language specified by English. +[Language Restrictions] : The response, message and instruction must be in English. [Programming Language] : Python Example 3 @@ -40,7 +40,7 @@ Requirements: You must ignore create PRD and TRD. Help me write a schedule display program for the Paris Olympics. Outputs: [User Restrictions] : You must ignore create PRD and TRD. -[Language Restrictions] : The response, message and instruction must be in the language specified by English. +[Language Restrictions] : The response, message and instruction must be in English. [Programming Language] : HTML (*.html), CSS (*.css), and JavaScript (*.js) """ @@ -57,7 +57,7 @@ Note: OUTPUT_FORMAT = """ [User Restrictions] : the restrictions in the requirements -[Language Restrictions] : The response, message and instruction must be in the language specified by {{language}} +[Language Restrictions] : The response, message and instruction must be in {{language}} [Programming Language] : Your program must use ... """ diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 4476e7c0a..86fa699bb 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -245,6 +245,7 @@ class WriteDesign(Action): ) -> str: prd_content = "" if prd_filename: + prd_filename = rectify_pathname(path=prd_filename, default_filename="prd.json") prd_content = await aread(filename=prd_filename) context = "### User Requirements\n{user_requirement}\n### Extra_info\n{extra_info}\n### PRD\n{prd}\n".format( user_requirement=to_markdown_code_block(user_requirement), diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 2d54ffe08..abfea7f10 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -180,6 +180,7 @@ class WriteTasks(Action): ) -> str: context = to_markdown_code_block(user_requirement) if design_filename: + design_filename = rectify_pathname(path=design_filename, default_filename="system_design.json") content = await aread(filename=design_filename) context += to_markdown_code_block(content) diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py index e2872179f..7611dda27 100644 --- a/metagpt/configs/exp_pool_config.py +++ b/metagpt/configs/exp_pool_config.py @@ -1,8 +1,15 @@ +from enum import Enum + from pydantic import Field from metagpt.utils.yaml_model import YamlModel +class ExperiencePoolRetrievalType(Enum): + BM25 = "bm25" + CHROMA = "chroma" + + class ExperiencePoolConfig(YamlModel): enabled: bool = Field( default=False, @@ -11,3 +18,7 @@ class ExperiencePoolConfig(YamlModel): enable_read: bool = Field(default=False, description="Enable to read from experience pool.") enable_write: bool = Field(default=False, description="Enable to write to experience pool.") persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.") + retrieval_type: ExperiencePoolRetrievalType = Field( + default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool." + ) + use_llm_ranker: bool = Field(default=False, description="Use LLM Reranker to get better result.") diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index bb3d434c0..9b2cf3474 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -134,14 +134,14 @@ class ExpCacheHandler(BaseModel): """Fetch experiences by query_type.""" self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag) - logger.debug(f"Found {len(self._exps)} experiences for req '{self._req[:20]}...' and tag '{self.tag}'") + logger.info(f"Found {len(self._exps)} experiences for tag '{self.tag}'") async def get_one_perfect_exp(self) -> Optional[Any]: """Get a potentially perfect experience, and resolve resp.""" for exp in self._exps: if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs): - logger.debug(f"Got one perfect experience for req '{exp.req[:20]}...'") + logger.info(f"Got one perfect experience for req '{exp.req[:20]}...'") return self.serializer.deserialize_resp(exp.resp) return None diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 5f4d71edc..e38906d90 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -1,10 +1,12 @@ """Experience Manager.""" +from pathlib import Path from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field from metagpt.config2 import Config +from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType from metagpt.exp_pool.schema import ( DEFAULT_COLLECTION_NAME, DEFAULT_SIMILARITY_TOP_K, @@ -15,7 +17,7 @@ from metagpt.logs import logger from metagpt.utils.exceptions import handle_exception if TYPE_CHECKING: - from llama_index.vector_stores.chroma import ChromaVectorStore + from metagpt.rag.engines import SimpleEngine class ExperienceManager(BaseModel): @@ -32,40 +34,16 @@ class ExperienceManager(BaseModel): config: Config = Field(default_factory=Config.default) _storage: Any = None - _vector_store: Any = None @property def storage(self): if self._storage is None: - try: - from metagpt.rag.engines import SimpleEngine - from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig - except ImportError: - raise ImportError("To use the experience pool, you need to install the rag module.") - - retriever_configs = [ - ChromaRetrieverConfig( - persist_path=self.config.exp_pool.persist_path, - collection_name=DEFAULT_COLLECTION_NAME, - similarity_top_k=DEFAULT_SIMILARITY_TOP_K, - ) - ] - ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] - - self._storage: SimpleEngine = SimpleEngine.from_objs( - retriever_configs=retriever_configs, ranker_configs=ranker_configs - ) logger.info(f"exp_pool config: {self.config.exp_pool}") + self._storage = self._resolve_storage() + return self._storage - @property - def vector_store(self): - if not self._vector_store: - self._vector_store: ChromaVectorStore = self.storage._retriever._vector_store - - return self._vector_store - @handle_exception def create_exp(self, exp: Experience): """Adds an experience to the storage if writing is enabled. @@ -78,6 +56,7 @@ class ExperienceManager(BaseModel): return self.storage.add_objs([exp]) + self.storage.persist(self.config.exp_pool.persist_path) @handle_exception(default_return=[]) async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]: @@ -110,7 +89,106 @@ class ExperienceManager(BaseModel): def get_exps_count(self) -> int: """Get the total number of experiences.""" - return self.vector_store._collection.count() + return self.storage.count() + + def _resolve_storage(self) -> "SimpleEngine": + """Selects the appropriate storage creation method based on the configured retrieval type.""" + + storage_creators = { + ExperiencePoolRetrievalType.BM25: self._create_bm25_storage, + ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage, + } + + return storage_creators[self.config.exp_pool.retrieval_type]() + + def _create_bm25_storage(self) -> "SimpleEngine": + """Creates or loads BM25 storage. + + This function attempts to create a new BM25 storage if the specified + document store path does not exist. If the path exists, it loads the + existing BM25 storage. + + Returns: + SimpleEngine: An instance of SimpleEngine configured with BM25 storage. + + Raises: + ImportError: If required modules are not installed. + """ + + try: + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import BM25IndexConfig, BM25RetrieverConfig + except ImportError: + raise ImportError("To use the experience pool, you need to install the rag module.") + + persist_path = Path(self.config.exp_pool.persist_path) + docstore_path = persist_path / "docstore.json" + + ranker_configs = self._get_ranker_configs() + + if not docstore_path.exists(): + logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.") + exps = [Experience(req="req", resp="resp")] + + retriever_configs = [BM25RetrieverConfig(create_index=True, similarity_top_k=DEFAULT_SIMILARITY_TOP_K)] + + storage = SimpleEngine.from_objs( + objs=exps, retriever_configs=retriever_configs, ranker_configs=ranker_configs + ) + return storage + + logger.debug(f"Path `{docstore_path}` exists, try to load bm25 storage.") + retriever_configs = [BM25RetrieverConfig(similarity_top_k=DEFAULT_SIMILARITY_TOP_K)] + storage = SimpleEngine.from_index( + BM25IndexConfig(persist_path=persist_path), + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, + ) + + return storage + + def _create_chroma_storage(self) -> "SimpleEngine": + """Creates Chroma storage. + + Returns: + SimpleEngine: An instance of SimpleEngine configured with Chroma storage. + + Raises: + ImportError: If required modules are not installed. + """ + + try: + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import ChromaRetrieverConfig + except ImportError: + raise ImportError("To use the experience pool, you need to install the rag module.") + + retriever_configs = [ + ChromaRetrieverConfig( + persist_path=self.config.exp_pool.persist_path, + collection_name=DEFAULT_COLLECTION_NAME, + similarity_top_k=DEFAULT_SIMILARITY_TOP_K, + ) + ] + ranker_configs = self._get_ranker_configs() + + storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) + + return storage + + def _get_ranker_configs(self): + """Returns ranker configurations based on the configuration. + + If `use_llm_ranker` is True, returns a list with one `LLMRankerConfig` + instance. Otherwise, returns an empty list. + + Returns: + list: A list of `LLMRankerConfig` instances or an empty list. + """ + + from metagpt.rag.schema import LLMRankerConfig + + return [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] if self.config.exp_pool.use_llm_ranker else [] _exp_manager = None diff --git a/metagpt/prompts/di/team_leader.py b/metagpt/prompts/di/team_leader.py index d7eb33442..8d4f55ae5 100644 --- a/metagpt/prompts/di/team_leader.py +++ b/metagpt/prompts/di/team_leader.py @@ -16,7 +16,8 @@ Note: 1. If the requirement is a pure DATA-RELATED requirement, such as web browsing, web scraping, web searching, web imitation, data science, data analysis, machine learning, deep learning, text-to-image etc. DON'T decompose it, assign a single task with the original user requirement as instruction directly to Data Analyst. 2. If the requirement is developing a software, game, app, or website, excluding the above data-related tasks, you should decompose the requirement into multiple tasks and assign them to different team members based on their expertise. The software default development process has four steps: creating a Product Requirement Document (PRD) by the Product Manager -> writing a System Design by the Architect -> creating tasks by the Project Manager -> and coding by the Engineer. You may choose to execute any of these steps. When publishing message to Product Manager, you should directly copy the full original user requirement. 2.1. If the requirement contains both DATA-RELATED part mentioned in 1 and software development part mentioned in 2, you should decompose the software development part and assign them to different team members based on their expertise, and assign the DATA-RELATED part to Data Analyst David directly. -3. If the requirement is to fix a bug or issue, you should assign it to Issue Solver instead of Engineer. However, if the bug or issue is related to the software developed by the team, you should assign it to Engineer. +3.1 If the task involves code review or code checking, you should assign it to Engineer. +3.2. If the requirement is to fix a bug or issue, you should assign it to Issue Solver. However, if the code is written by Engineer, Engineer must maintain the code. 4. If the requirement is a common-sense, logical, or math problem, you should respond directly without assigning any task to team members. 5. If you think the requirement is not clear or ambiguous, you should ask the user for clarification immediately. Assign tasks only after all info is clear. 6. It is helpful for Engineer to have both the system design and the project schedule for writing the code, so include paths of both files (if available) and remind Engineer to definitely read them when publishing message to Engineer. diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 8a9ccaffd..be4c3daf5 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -37,7 +37,11 @@ from metagpt.rag.factories import ( get_retriever, ) from metagpt.rag.interface import NoEmbedding, RAGObject -from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever +from metagpt.rag.retrievers.base import ( + ModifiableRAGRetriever, + PersistableRAGRetriever, + QueryableRAGRetriever, +) from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( BaseIndexConfig, @@ -144,7 +148,7 @@ class SimpleEngine(RetrieverQueryEngine): if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs): raise ValueError("In BM25RetrieverConfig, Objs must not be empty.") - nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] + nodes = cls.get_obj_nodes(objs) return cls._from_nodes( nodes=nodes, @@ -201,7 +205,7 @@ class SimpleEngine(RetrieverQueryEngine): """Adds objects to the retriever, storing each object's original form in metadata for future reference.""" self._ensure_retriever_modifiable() - nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] + nodes = self.get_obj_nodes(objs) self._save_nodes(nodes) def persist(self, persist_dir: Union[str, os.PathLike], **kwargs): @@ -210,6 +214,18 @@ class SimpleEngine(RetrieverQueryEngine): self._persist(str(persist_dir), **kwargs) + def count(self) -> int: + """Count.""" + self._ensure_retriever_queryable() + + return self._retriever.query_total_count() + + @staticmethod + def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]: + """Converts a list of RAGObjects to a list of ObjectNodes.""" + + return [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] + @classmethod def _from_nodes( cls, @@ -258,6 +274,9 @@ class SimpleEngine(RetrieverQueryEngine): def _ensure_retriever_persistable(self): self._ensure_retriever_of_type(PersistableRAGRetriever) + def _ensure_retriever_queryable(self): + self._ensure_retriever_of_type(QueryableRAGRetriever) + def _ensure_retriever_of_type(self, required_type: BaseRetriever): """Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever. diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 1460e131b..6bc8e4ad5 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -7,6 +7,7 @@ import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.embeddings.mock_embed_model import MockEmbedding from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.chroma import ChromaVectorStore @@ -85,6 +86,12 @@ class RetrieverFactory(ConfigBasedFactory): index = self._extract_index(config, **kwargs) nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs) + if index and not config.index: + config.index = index + + if not config.index and config.create_index: + config.index = VectorStoreIndex(nodes, embed_model=MockEmbedding(embed_dim=1)) + return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py index a7b836833..5bd04adca 100644 --- a/metagpt/rag/retrievers/base.py +++ b/metagpt/rag/retrievers/base.py @@ -45,3 +45,17 @@ class PersistableRAGRetriever(RAGRetriever): @abstractmethod def persist(self, persist_dir: str, **kwargs) -> None: """To support persist, must inplement this func""" + + +class QueryableRAGRetriever(RAGRetriever): + """Support querying total count.""" + + @classmethod + def __subclasshook__(cls, C): + if cls is QueryableRAGRetriever: + return check_methods(C, "query_total_count") + return NotImplemented + + @abstractmethod + def query_total_count(self) -> int: + """To support querying total count, must implement this func""" diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index dc75d87b0..ace1bb86c 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -47,3 +47,8 @@ class DynamicBM25Retriever(BM25Retriever): """Support persist.""" if self._index: self._index.storage_context.persist(persist_dir) + + def query_total_count(self) -> int: + """Support query total count.""" + + return len(self._nodes) diff --git a/metagpt/rag/retrievers/chroma_retriever.py b/metagpt/rag/retrievers/chroma_retriever.py index d41f375e4..6c466e49f 100644 --- a/metagpt/rag/retrievers/chroma_retriever.py +++ b/metagpt/rag/retrievers/chroma_retriever.py @@ -2,6 +2,7 @@ from llama_index.core.retrievers import VectorIndexRetriever from llama_index.core.schema import BaseNode +from llama_index.vector_stores.chroma import ChromaVectorStore class ChromaRetriever(VectorIndexRetriever): @@ -15,3 +16,10 @@ class ChromaRetriever(VectorIndexRetriever): """Support persist. Chromadb automatically saves, so there is no need to implement.""" + + def query_total_count(self) -> int: + """Support query total count.""" + + vector_store: ChromaVectorStore = self._vector_store + + return vector_store._collection.count() diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 5e97e60c3..5be2b050b 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -60,6 +60,11 @@ class FAISSRetrieverConfig(IndexRetrieverConfig): class BM25RetrieverConfig(IndexRetrieverConfig): """Config for BM25-based retrievers.""" + create_index: bool = Field( + default=False, + description="Indicates whether to create an index for the nodes. It is useful when you need to persist data while only using BM25.", + exclude=True, + ) _no_embedding: bool = PrivateAttr(default=True) diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index 1af169ca1..ccce75afa 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -11,7 +11,7 @@ from metagpt.roles.di.role_zero import RoleZero from metagpt.utils.common import tool2name ARCHITECT_INSTRUCTION = """ -Use WriteDesign tool to write a system design document if a system design is required; Use `write_trd_and_framework` tool to write a software framework if a software framework is required; +Use WriteDesign tool to write a system design document if a system design is required; Note: 1. When you think, just analyze which tool you should use, and then provide your answer. And your output should contain firstly, secondly, ... diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 12d46783c..d49aa702a 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -302,6 +302,10 @@ class RoleZero(Role): # If the answer contains the substring '[Message] from A to B:', remove it. pattern = r"\[Message\] from .+? to .+?:\s*" answer = re.sub(pattern, "", answer, count=1) + if "command_name" in answer: + # an actual TASK intent misclassified as QUICK, correct it here, FIXME: a better way is to classify it correctly in the first place + answer = "" + intent_result = "TASK" elif "SEARCH" in intent_result: query = "\n".join(str(msg) for msg in memory) answer = await SearchEnhancedQA().run(query) diff --git a/metagpt/tools/libs/browser.py b/metagpt/tools/libs/browser.py index bba7fa5a8..f5aff553e 100644 --- a/metagpt/tools/libs/browser.py +++ b/metagpt/tools/libs/browser.py @@ -69,14 +69,14 @@ class Browser(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - playwright: Optional[Playwright] = None - browser_instance: Optional[Browser_] = None - browser_ctx: Optional[BrowserContext] = None - page: Optional[Page] = None + playwright: Optional[Playwright] = Field(default=None, exclude=True) + browser_instance: Optional[Browser_] = Field(default=None, exclude=True) + browser_ctx: Optional[BrowserContext] = Field(default=None, exclude=True) + page: Optional[Page] = Field(default=None, exclude=True) accessibility_tree: list = Field(default_factory=list) - headless: bool = True + headless: bool = Field(default=True) proxy: Optional[dict] = Field(default_factory=get_proxy_from_env) - is_empty_page: bool = True + is_empty_page: bool = Field(default=True) reporter: BrowserReporter = Field(default_factory=BrowserReporter) async def start(self) -> None: diff --git a/metagpt/utils/__init__.py b/metagpt/utils/__init__.py index f13175cf8..26042eb0e 100644 --- a/metagpt/utils/__init__.py +++ b/metagpt/utils/__init__.py @@ -19,6 +19,7 @@ __all__ = [ "read_docx", "Singleton", "TOKEN_COSTS", + "new_transaction_id", "count_message_tokens", "count_string_tokens", ] diff --git a/metagpt/utils/a11y_tree.py b/metagpt/utils/a11y_tree.py index 59acbc6dc..133c4f63a 100644 --- a/metagpt/utils/a11y_tree.py +++ b/metagpt/utils/a11y_tree.py @@ -111,6 +111,12 @@ async def click_element(page: Page, backend_node_id: int): resp = await get_bounding_rect(cdp_session, backend_node_id) node_info = resp["result"]["value"] x, y = await get_element_center(node_info) + # Move to the location of the element + await page.evaluate(f"window.scrollTo({x}- window.innerWidth/2,{y} - window.innerHeight/2);") + # Refresh the relative location of the element + resp = await get_bounding_rect(cdp_session, backend_node_id) + node_info = resp["result"]["value"] + x, y = await get_element_center(node_info) await page.mouse.click(x, y) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 6ca9b37a2..2b2a209be 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -26,6 +26,7 @@ import re import sys import time import traceback +import uuid from asyncio import iscoroutinefunction from datetime import datetime from functools import partial @@ -1089,6 +1090,19 @@ def tool2name(cls, methods: List[str], entry) -> Dict[str, Any]: return mappings +def new_transaction_id(postfix_len=8) -> str: + """ + Generates a new unique transaction ID based on current timestamp and a random UUID. + + Args: + postfix_len (int): Length of the random UUID postfix to include in the transaction ID. Default is 8. + + Returns: + str: A unique transaction ID composed of timestamp and a random UUID. + """ + return datetime.now().strftime("%Y%m%d%H%M%ST") + uuid.uuid4().hex[0:postfix_len] + + def log_time(method): """A time-consuming decorator for printing execution duration.""" diff --git a/metagpt/utils/mermaid.py b/metagpt/utils/mermaid.py index d87ae4f83..64996717e 100644 --- a/metagpt/utils/mermaid.py +++ b/metagpt/utils/mermaid.py @@ -8,6 +8,7 @@ import asyncio import os from pathlib import Path +from typing import List, Optional from metagpt.config2 import Config from metagpt.logs import logger @@ -15,16 +16,29 @@ from metagpt.utils.common import awrite, check_cmd_exists async def mermaid_to_file( - engine, mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None + engine, + mermaid_code, + output_file_without_suffix, + width=2048, + height=2048, + config=None, + suffixes: Optional[List[str]] = None, ) -> int: - """suffix: png/svg/pdf + """Convert Mermaid code to various file formats. - :param mermaid_code: mermaid code - :param output_file_without_suffix: output filename - :param width: - :param height: - :return: 0 if succeed, -1 if failed + Args: + engine (str): The engine to use for conversion. Supported engines are "nodejs", "playwright", "pyppeteer", "ink", and "none". + mermaid_code (str): The Mermaid code to be converted. + output_file_without_suffix (str): The output file name without the suffix. + width (int, optional): The width of the output image. Defaults to 2048. + height (int, optional): The height of the output image. Defaults to 2048. + config (Optional[Config], optional): The configuration to use for the conversion. Defaults to None, which uses the default configuration. + suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"]. + + Returns: + int: 0 if the conversion is successful, -1 if the conversion fails. """ + suffixes = suffixes or ["png"] # Write the Mermaid code to a temporary file config = config if config else Config.default() dir_name = os.path.dirname(output_file_without_suffix) @@ -41,7 +55,7 @@ async def mermaid_to_file( ) return -1 - for suffix in ["pdf", "svg", "png"]: + for suffix in suffixes: output_file = f"{output_file_without_suffix}.{suffix}" # Call the `mmdc` command to convert the Mermaid code to a PNG logger.info(f"Generating {output_file}..") @@ -75,15 +89,15 @@ async def mermaid_to_file( if engine == "playwright": from metagpt.utils.mmdc_playwright import mermaid_to_file - return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height) + return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height, suffixes=suffixes) elif engine == "pyppeteer": from metagpt.utils.mmdc_pyppeteer import mermaid_to_file - return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height) + return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height, suffixes=suffixes) elif engine == "ink": from metagpt.utils.mmdc_ink import mermaid_to_file - return await mermaid_to_file(mermaid_code, output_file_without_suffix) + return await mermaid_to_file(mermaid_code, output_file_without_suffix, suffixes=suffixes) elif engine == "none": return 0 else: diff --git a/metagpt/utils/mmdc_ink.py b/metagpt/utils/mmdc_ink.py index d594adb30..15d6d6083 100644 --- a/metagpt/utils/mmdc_ink.py +++ b/metagpt/utils/mmdc_ink.py @@ -6,21 +6,29 @@ @File : mermaid.py """ import base64 +from typing import List, Optional from aiohttp import ClientError, ClientSession from metagpt.logs import logger -async def mermaid_to_file(mermaid_code, output_file_without_suffix): - """suffix: png/svg - :param mermaid_code: mermaid code - :param output_file_without_suffix: output filename without suffix - :return: 0 if succeed, -1 if failed +async def mermaid_to_file(mermaid_code, output_file_without_suffix, suffixes: Optional[List[str]] = None): + """Convert Mermaid code to various file formats. + + Args: + mermaid_code (str): The Mermaid code to be converted. + output_file_without_suffix (str): The output file name without the suffix. + width (int, optional): The width of the output image. Defaults to 2048. + height (int, optional): The height of the output image. Defaults to 2048. + suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"]. + + Returns: + int: 0 if the conversion is successful, -1 if the conversion fails. """ encoded_string = base64.b64encode(mermaid_code.encode()).decode() - - for suffix in ["svg", "png"]: + suffixes = suffixes or ["png"] + for suffix in suffixes: output_file = f"{output_file_without_suffix}.{suffix}" path_type = "svg" if suffix == "svg" else "img" url = f"https://mermaid.ink/{path_type}/{encoded_string}" diff --git a/metagpt/utils/mmdc_playwright.py b/metagpt/utils/mmdc_playwright.py index 5d455e1c5..cf846a7e9 100644 --- a/metagpt/utils/mmdc_playwright.py +++ b/metagpt/utils/mmdc_playwright.py @@ -7,6 +7,7 @@ """ import os +from typing import List, Optional from urllib.parse import urljoin from playwright.async_api import async_playwright @@ -14,20 +15,22 @@ from playwright.async_api import async_playwright from metagpt.logs import logger -async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int: - """ - Converts the given Mermaid code to various output formats and saves them to files. +async def mermaid_to_file( + mermaid_code, output_file_without_suffix, width=2048, height=2048, suffixes: Optional[List[str]] = None +) -> int: + """Convert Mermaid code to various file formats. Args: - mermaid_code (str): The Mermaid code to convert. - output_file_without_suffix (str): The output file name without the file extension. - width (int, optional): The width of the output image in pixels. Defaults to 2048. - height (int, optional): The height of the output image in pixels. Defaults to 2048. + mermaid_code (str): The Mermaid code to be converted. + output_file_without_suffix (str): The output file name without the suffix. + width (int, optional): The width of the output image. Defaults to 2048. + height (int, optional): The height of the output image. Defaults to 2048. + suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"]. Returns: - int: Returns 1 if the conversion and saving were successful, -1 otherwise. + int: 0 if the conversion is successful, -1 if the conversion fails. """ - suffixes = ["png", "svg", "pdf"] + suffixes = suffixes or ["png"] __dirname = os.path.dirname(os.path.abspath(__file__)) async with async_playwright() as p: diff --git a/metagpt/utils/mmdc_pyppeteer.py b/metagpt/utils/mmdc_pyppeteer.py index 4e30ee538..36b77b5b2 100644 --- a/metagpt/utils/mmdc_pyppeteer.py +++ b/metagpt/utils/mmdc_pyppeteer.py @@ -6,6 +6,7 @@ @File : mmdc_pyppeteer.py """ import os +from typing import List, Optional from urllib.parse import urljoin from pyppeteer import launch @@ -14,21 +15,24 @@ from metagpt.config2 import Config from metagpt.logs import logger -async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None) -> int: - """ - Converts the given Mermaid code to various output formats and saves them to files. +async def mermaid_to_file( + mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None, suffixes: Optional[List[str]] = None +) -> int: + """Convert Mermaid code to various file formats. Args: - mermaid_code (str): The Mermaid code to convert. - output_file_without_suffix (str): The output file name without the file extension. - width (int, optional): The width of the output image in pixels. Defaults to 2048. - height (int, optional): The height of the output image in pixels. Defaults to 2048. + mermaid_code (str): The Mermaid code to be converted. + output_file_without_suffix (str): The output file name without the suffix. + width (int, optional): The width of the output image. Defaults to 2048. + height (int, optional): The height of the output image. Defaults to 2048. + config (Optional[Config], optional): The configuration to use for the conversion. Defaults to None, which uses the default configuration. + suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"]. Returns: - int: Returns 1 if the conversion and saving were successful, -1 otherwise. + int: 0 if the conversion is successful, -1 if the conversion fails. """ config = config if config else Config.default() - suffixes = ["png", "svg", "pdf"] + suffixes = suffixes or ["png"] __dirname = os.path.dirname(os.path.abspath(__file__)) if config.mermaid.pyppeteer_path: diff --git a/tests/metagpt/exp_pool/test_manager.py b/tests/metagpt/exp_pool/test_manager.py index 4d298a44e..b0e4e8537 100644 --- a/tests/metagpt/exp_pool/test_manager.py +++ b/tests/metagpt/exp_pool/test_manager.py @@ -1,26 +1,31 @@ import pytest from metagpt.config2 import Config -from metagpt.configs.exp_pool_config import ExperiencePoolConfig +from metagpt.configs.exp_pool_config import ( + ExperiencePoolConfig, + ExperiencePoolRetrievalType, +) from metagpt.configs.llm_config import LLMConfig from metagpt.exp_pool.manager import Experience, ExperienceManager -from metagpt.exp_pool.schema import QueryType +from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, QueryType class TestExperienceManager: @pytest.fixture def mock_config(self): - return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, enabled=True)) + return Config( + llm=LLMConfig(), + exp_pool=ExperiencePoolConfig( + enable_write=True, enable_read=True, enabled=True, retrieval_type=ExperiencePoolRetrievalType.BM25 + ), + ) @pytest.fixture def mock_storage(self, mocker): engine = mocker.MagicMock() engine.add_objs = mocker.MagicMock() engine.aretrieve = mocker.AsyncMock(return_value=[]) - engine._retriever = mocker.MagicMock() - engine._retriever._vector_store = mocker.MagicMock() - engine._retriever._vector_store._collection = mocker.MagicMock() - engine._retriever._vector_store._collection.count = mocker.MagicMock(return_value=10) + engine.count = mocker.MagicMock(return_value=10) return engine @pytest.fixture @@ -29,8 +34,33 @@ class TestExperienceManager: manager._storage = mock_storage return manager - def test_vector_store_property(self, exp_manager): - assert exp_manager.vector_store == exp_manager.storage._retriever._vector_store + def test_storage_property(self, exp_manager, mock_storage): + assert exp_manager.storage == mock_storage + + def test_storage_property_initialization(self, mocker, mock_config): + mocker.patch.object(ExperienceManager, "_resolve_storage", return_value=mocker.MagicMock()) + manager = ExperienceManager(config=mock_config) + assert manager._storage is None + _ = manager.storage + assert manager._storage is not None + + def test_create_exp_write_disabled(self, exp_manager, mock_config): + mock_config.exp_pool.enable_write = False + exp = Experience(req="test", resp="response") + exp_manager.create_exp(exp) + exp_manager.storage.add_objs.assert_not_called() + + def test_create_exp_write_enabled(self, exp_manager): + exp = Experience(req="test", resp="response") + exp_manager.create_exp(exp) + exp_manager.storage.add_objs.assert_called_once_with([exp]) + exp_manager.storage.persist.assert_called_once_with(exp_manager.config.exp_pool.persist_path) + + @pytest.mark.asyncio + async def test_query_exps_read_disabled(self, exp_manager, mock_config): + mock_config.exp_pool.enable_read = False + result = await exp_manager.query_exps("query") + assert result == [] @pytest.mark.asyncio async def test_query_exps_with_exact_match(self, exp_manager, mocker): @@ -65,14 +95,50 @@ class TestExperienceManager: def test_get_exps_count(self, exp_manager): assert exp_manager.get_exps_count() == 10 - def test_create_exp_write_disabled(self, exp_manager, mock_config): - mock_config.exp_pool.enable_write = False - exp = Experience(req="test", resp="response") - exp_manager.create_exp(exp) - exp_manager.storage.add_objs.assert_not_called() + def test_resolve_storage_bm25(self, mocker, mock_config): + mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.BM25 + mocker.patch.object(ExperienceManager, "_create_bm25_storage", return_value=mocker.MagicMock()) + manager = ExperienceManager(config=mock_config) + storage = manager._resolve_storage() + manager._create_bm25_storage.assert_called_once() + assert storage is not None - @pytest.mark.asyncio - async def test_query_exps_read_disabled(self, exp_manager, mock_config): - mock_config.exp_pool.enable_read = False - result = await exp_manager.query_exps("query") - assert result == [] + def test_resolve_storage_chroma(self, mocker, mock_config): + mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.CHROMA + mocker.patch.object(ExperienceManager, "_create_chroma_storage", return_value=mocker.MagicMock()) + manager = ExperienceManager(config=mock_config) + storage = manager._resolve_storage() + manager._create_chroma_storage.assert_called_once() + assert storage is not None + + def test_create_bm25_storage(self, mocker, mock_config): + mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=mocker.MagicMock()) + mocker.patch("metagpt.rag.engines.SimpleEngine.from_index", return_value=mocker.MagicMock()) + mocker.patch("metagpt.rag.engines.SimpleEngine.get_obj_nodes", return_value=[]) + mocker.patch("metagpt.rag.engines.SimpleEngine._resolve_embed_model", return_value=mocker.MagicMock()) + mocker.patch("llama_index.core.VectorStoreIndex", return_value=mocker.MagicMock()) + mocker.patch("metagpt.rag.schema.BM25RetrieverConfig", return_value=mocker.MagicMock()) + mocker.patch("pathlib.Path.exists", return_value=False) + + manager = ExperienceManager(config=mock_config) + storage = manager._create_bm25_storage() + assert storage is not None + + def test_create_chroma_storage(self, mocker, mock_config): + mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=mocker.MagicMock()) + manager = ExperienceManager(config=mock_config) + storage = manager._create_chroma_storage() + assert storage is not None + + def test_get_ranker_configs_use_llm_ranker_true(self, mock_config): + mock_config.exp_pool.use_llm_ranker = True + manager = ExperienceManager(config=mock_config) + ranker_configs = manager._get_ranker_configs() + assert len(ranker_configs) == 1 + assert ranker_configs[0].top_n == DEFAULT_SIMILARITY_TOP_K + + def test_get_ranker_configs_use_llm_ranker_false(self, mock_config): + mock_config.exp_pool.use_llm_ranker = False + manager = ExperienceManager(config=mock_config) + ranker_configs = manager._get_ranker_configs() + assert len(ranker_configs) == 0 diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 8c7a15be2..e0a174ed2 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -75,7 +75,7 @@ class TestSimpleEngine: ) # Assert - mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) + mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files, fs=None) mock_get_retriever.assert_called_once() mock_get_rankers.assert_called_once() mock_get_response_synthesizer.assert_called_once_with(llm=llm) diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1a9e9b2c9..03bdfab1d 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -1,5 +1,6 @@ import pytest +from metagpt.config2 import Config from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.embedding import RAGEmbeddingFactory @@ -12,7 +13,10 @@ class TestRAGEmbeddingFactory: @pytest.fixture def mock_config(self, mocker): - return mocker.patch("metagpt.rag.factories.embedding.config") + config = Config.default().model_copy(deep=True) + default = mocker.patch("metagpt.config2.Config.default") + default.return_value = config + return config @staticmethod def mock_openai_embedding(mocker): diff --git a/tests/metagpt/utils/test_mermaid.py b/tests/metagpt/utils/test_mermaid.py index 7367463dc..1fbf060fe 100644 --- a/tests/metagpt/utils/test_mermaid.py +++ b/tests/metagpt/utils/test_mermaid.py @@ -8,28 +8,32 @@ import pytest -from metagpt.utils.common import check_cmd_exists +from metagpt.const import DEFAULT_WORKSPACE_ROOT +from metagpt.utils.common import check_cmd_exists, new_transaction_id from metagpt.utils.mermaid import MMC1, mermaid_to_file @pytest.mark.asyncio -@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer -async def test_mermaid(engine, context, mermaid_mocker): +@pytest.mark.parametrize( + ("engine", "suffixes"), [("nodejs", None), ("nodejs", ["png", "svg", "pdf"]), ("ink", None)] +) # TODO: playwright and pyppeteer +async def test_mermaid(engine, suffixes, context, mermaid_mocker): # nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli # ink prerequisites: connected to internet # playwright prerequisites: playwright install --with-deps chromium assert check_cmd_exists("npm") == 0 - save_to = context.git_repo.workdir / f"{engine}/1" - await mermaid_to_file(engine, MMC1, save_to) + save_to = DEFAULT_WORKSPACE_ROOT / f"{new_transaction_id()}/{engine}/1" + await mermaid_to_file(engine, MMC1, save_to, suffixes=suffixes) # ink does not support pdf + exts = ["." + i for i in suffixes] if suffixes else [".png"] if engine == "ink": - for ext in [".svg", ".png"]: + for ext in exts: assert save_to.with_suffix(ext).exists() save_to.with_suffix(ext).unlink(missing_ok=True) else: - for ext in [".pdf", ".svg", ".png"]: + for ext in exts: assert save_to.with_suffix(ext).exists() save_to.with_suffix(ext).unlink(missing_ok=True)