Merge branch 'mgx_ops' into fibug/quick_think_output_message_tag

This commit is contained in:
黄伟韬 2024-08-20 18:10:06 +08:00
commit e690488b08
29 changed files with 390 additions and 110 deletions

View file

@ -79,6 +79,8 @@ exp_pool:
enable_read: false
enable_write: false
persist_path: .chroma_exp_data # The directory where the experience pool data is persisted.
retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding.
use_llm_ranker: false # If `use_llm_ranker` is true, an LLM reranker is used to get better results, but the LLM output is not always guaranteed to be parseable for reranking.
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
azure_tts_region: "eastus"

View file

@ -3,7 +3,7 @@ # Experience Pool
## Prerequisites
- Ensure the RAG module is installed: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html
- Set embedding: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html
- Set both `enable_read` and `enable_write` to `true` in the `exp_pool` section of `config2.yaml`
- Set `enabled`, `enable_read` and `enable_write` to `true` in the `exp_pool` section of `config2.yaml`
## Example Files

View file

@ -24,7 +24,7 @@ Requirements:
创建一个贪吃蛇只需要给出设计文档和代码
Outputs:
[User Restrictions] : 只需要给出设计文档和代码.
[Language Restrictions] : The response, message and instruction must be in the language specified by Chinese.
[Language Restrictions] : The response, message and instruction must be in Chinese.
[Programming Language] : HTML (*.html), CSS (*.css), and JavaScript (*.js)
Example 2
@ -32,7 +32,7 @@ Requirements:
Create 2048 game using Python. Do not write PRD.
Outputs:
[User Restrictions] : Do not write PRD.
[Language Restrictions] : The response, message and instruction must be in the language specified by English.
[Language Restrictions] : The response, message and instruction must be in English.
[Programming Language] : Python
Example 3
@ -40,7 +40,7 @@ Requirements:
You must ignore create PRD and TRD. Help me write a schedule display program for the Paris Olympics.
Outputs:
[User Restrictions] : You must ignore create PRD and TRD.
[Language Restrictions] : The response, message and instruction must be in the language specified by English.
[Language Restrictions] : The response, message and instruction must be in English.
[Programming Language] : HTML (*.html), CSS (*.css), and JavaScript (*.js)
"""
@ -57,7 +57,7 @@ Note:
OUTPUT_FORMAT = """
[User Restrictions] : the restrictions in the requirements
[Language Restrictions] : The response, message and instruction must be in the language specified by {{language}}
[Language Restrictions] : The response, message and instruction must be in {{language}}
[Programming Language] : Your program must use ...
"""

View file

@ -245,6 +245,7 @@ class WriteDesign(Action):
) -> str:
prd_content = ""
if prd_filename:
prd_filename = rectify_pathname(path=prd_filename, default_filename="prd.json")
prd_content = await aread(filename=prd_filename)
context = "### User Requirements\n{user_requirement}\n### Extra_info\n{extra_info}\n### PRD\n{prd}\n".format(
user_requirement=to_markdown_code_block(user_requirement),

View file

@ -180,6 +180,7 @@ class WriteTasks(Action):
) -> str:
context = to_markdown_code_block(user_requirement)
if design_filename:
design_filename = rectify_pathname(path=design_filename, default_filename="system_design.json")
content = await aread(filename=design_filename)
context += to_markdown_code_block(content)

View file

@ -1,8 +1,15 @@
from enum import Enum
from pydantic import Field
from metagpt.utils.yaml_model import YamlModel
class ExperiencePoolRetrievalType(Enum):
    """Retrieval types that can be configured for the experience pool."""

    # BM25 retrieval.
    BM25 = "bm25"
    # Chroma vector-store retrieval.
    CHROMA = "chroma"
class ExperiencePoolConfig(YamlModel):
enabled: bool = Field(
default=False,
@ -11,3 +18,7 @@ class ExperiencePoolConfig(YamlModel):
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
retrieval_type: ExperiencePoolRetrievalType = Field(
default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool."
)
use_llm_ranker: bool = Field(default=False, description="Use LLM Reranker to get better result.")

View file

@ -134,14 +134,14 @@ class ExpCacheHandler(BaseModel):
"""Fetch experiences by query_type."""
self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag)
logger.debug(f"Found {len(self._exps)} experiences for req '{self._req[:20]}...' and tag '{self.tag}'")
logger.info(f"Found {len(self._exps)} experiences for tag '{self.tag}'")
async def get_one_perfect_exp(self) -> Optional[Any]:
"""Get a potentially perfect experience, and resolve resp."""
for exp in self._exps:
if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
logger.debug(f"Got one perfect experience for req '{exp.req[:20]}...'")
logger.info(f"Got one perfect experience for req '{exp.req[:20]}...'")
return self.serializer.deserialize_resp(exp.resp)
return None

View file

@ -1,10 +1,12 @@
"""Experience Manager."""
from pathlib import Path
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict, Field
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType
from metagpt.exp_pool.schema import (
DEFAULT_COLLECTION_NAME,
DEFAULT_SIMILARITY_TOP_K,
@ -15,7 +17,7 @@ from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
if TYPE_CHECKING:
from llama_index.vector_stores.chroma import ChromaVectorStore
from metagpt.rag.engines import SimpleEngine
class ExperienceManager(BaseModel):
@ -32,40 +34,16 @@ class ExperienceManager(BaseModel):
config: Config = Field(default_factory=Config.default)
_storage: Any = None
_vector_store: Any = None
@property
def storage(self):
if self._storage is None:
try:
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
except ImportError:
raise ImportError("To use the experience pool, you need to install the rag module.")
retriever_configs = [
ChromaRetrieverConfig(
persist_path=self.config.exp_pool.persist_path,
collection_name=DEFAULT_COLLECTION_NAME,
similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
)
]
ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)]
self._storage: SimpleEngine = SimpleEngine.from_objs(
retriever_configs=retriever_configs, ranker_configs=ranker_configs
)
logger.info(f"exp_pool config: {self.config.exp_pool}")
self._storage = self._resolve_storage()
return self._storage
@property
def vector_store(self):
if not self._vector_store:
self._vector_store: ChromaVectorStore = self.storage._retriever._vector_store
return self._vector_store
@handle_exception
def create_exp(self, exp: Experience):
"""Adds an experience to the storage if writing is enabled.
@ -78,6 +56,7 @@ class ExperienceManager(BaseModel):
return
self.storage.add_objs([exp])
self.storage.persist(self.config.exp_pool.persist_path)
@handle_exception(default_return=[])
async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]:
@ -110,7 +89,106 @@ class ExperienceManager(BaseModel):
def get_exps_count(self) -> int:
"""Get the total number of experiences."""
return self.vector_store._collection.count()
return self.storage.count()
def _resolve_storage(self) -> "SimpleEngine":
    """Build the storage engine that matches the configured `exp_pool.retrieval_type`.

    Returns:
        SimpleEngine: The engine returned by the creator registered for the configured retrieval type.

    Raises:
        KeyError: If the configured retrieval type has no registered creator.
    """
    creators_by_type = {
        ExperiencePoolRetrievalType.BM25: self._create_bm25_storage,
        ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage,
    }
    selected_type = self.config.exp_pool.retrieval_type
    create_storage = creators_by_type[selected_type]
    return create_storage()
def _create_bm25_storage(self) -> "SimpleEngine":
    """Creates or loads BM25 storage.

    A new BM25 storage is created when `docstore.json` does not exist under the
    configured persist path. If the file exists, the existing BM25 storage is
    loaded from the persist path instead.

    Returns:
        SimpleEngine: An instance of SimpleEngine configured with BM25 storage.

    Raises:
        ImportError: If required modules are not installed. The original import
            error is attached as the cause.
    """
    try:
        from metagpt.rag.engines import SimpleEngine
        from metagpt.rag.schema import BM25IndexConfig, BM25RetrieverConfig
    except ImportError as exc:
        # Chain explicitly so the traceback shows which dependency was actually missing.
        raise ImportError("To use the experience pool, you need to install the rag module.") from exc

    persist_path = Path(self.config.exp_pool.persist_path)
    docstore_path = persist_path / "docstore.json"
    ranker_configs = self._get_ranker_configs()

    if not docstore_path.exists():
        logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.")
        # Seed with one placeholder experience: SimpleEngine.from_objs rejects an empty
        # object list when a BM25RetrieverConfig is used.
        exps = [Experience(req="req", resp="resp")]
        # create_index=True asks the retriever factory to build an index for the nodes,
        # which is what allows this BM25-only storage to be persisted later.
        retriever_configs = [BM25RetrieverConfig(create_index=True, similarity_top_k=DEFAULT_SIMILARITY_TOP_K)]
        storage = SimpleEngine.from_objs(
            objs=exps, retriever_configs=retriever_configs, ranker_configs=ranker_configs
        )
        return storage

    logger.debug(f"Path `{docstore_path}` exists, try to load bm25 storage.")
    retriever_configs = [BM25RetrieverConfig(similarity_top_k=DEFAULT_SIMILARITY_TOP_K)]
    storage = SimpleEngine.from_index(
        BM25IndexConfig(persist_path=persist_path),
        retriever_configs=retriever_configs,
        ranker_configs=ranker_configs,
    )
    return storage
def _create_chroma_storage(self) -> "SimpleEngine":
    """Creates Chroma storage.

    The Chroma retriever is configured with the experience pool's persist path,
    the default collection name and the default similarity top-k.

    Returns:
        SimpleEngine: An instance of SimpleEngine configured with Chroma storage.

    Raises:
        ImportError: If required modules are not installed. The original import
            error is attached as the cause.
    """
    try:
        from metagpt.rag.engines import SimpleEngine
        from metagpt.rag.schema import ChromaRetrieverConfig
    except ImportError as exc:
        # Chain explicitly so the traceback shows which dependency was actually missing.
        raise ImportError("To use the experience pool, you need to install the rag module.") from exc

    retriever_configs = [
        ChromaRetrieverConfig(
            persist_path=self.config.exp_pool.persist_path,
            collection_name=DEFAULT_COLLECTION_NAME,
            similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
        )
    ]
    ranker_configs = self._get_ranker_configs()
    storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
    return storage
def _get_ranker_configs(self):
    """Build the ranker configuration list from the experience pool settings.

    Returns:
        list: A single-element list holding an `LLMRankerConfig` (with `top_n` set to
            `DEFAULT_SIMILARITY_TOP_K`) when `use_llm_ranker` is enabled, otherwise an empty list.
    """
    from metagpt.rag.schema import LLMRankerConfig

    if not self.config.exp_pool.use_llm_ranker:
        return []
    return [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)]
_exp_manager = None

View file

@ -16,7 +16,8 @@ Note:
1. If the requirement is a pure DATA-RELATED requirement, such as web browsing, web scraping, web searching, web imitation, data science, data analysis, machine learning, deep learning, text-to-image etc. DON'T decompose it, assign a single task with the original user requirement as instruction directly to Data Analyst.
2. If the requirement is developing a software, game, app, or website, excluding the above data-related tasks, you should decompose the requirement into multiple tasks and assign them to different team members based on their expertise. The software default development process has four steps: creating a Product Requirement Document (PRD) by the Product Manager -> writing a System Design by the Architect -> creating tasks by the Project Manager -> and coding by the Engineer. You may choose to execute any of these steps. When publishing message to Product Manager, you should directly copy the full original user requirement.
2.1. If the requirement contains both DATA-RELATED part mentioned in 1 and software development part mentioned in 2, you should decompose the software development part and assign them to different team members based on their expertise, and assign the DATA-RELATED part to Data Analyst David directly.
3. If the requirement is to fix a bug or issue, you should assign it to Issue Solver instead of Engineer. However, if the bug or issue is related to the software developed by the team, you should assign it to Engineer.
3.1 If the task involves code review or code checking, you should assign it to Engineer.
3.2. If the requirement is to fix a bug or issue, you should assign it to Issue Solver. However, if the code is written by Engineer, Engineer must maintain the code.
4. If the requirement is a common-sense, logical, or math problem, you should respond directly without assigning any task to team members.
5. If you think the requirement is not clear or ambiguous, you should ask the user for clarification immediately. Assign tasks only after all info is clear.
6. It is helpful for Engineer to have both the system design and the project schedule for writing the code, so include paths of both files (if available) and remind Engineer to definitely read them when publishing message to Engineer.

View file

@ -37,7 +37,11 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.base import (
ModifiableRAGRetriever,
PersistableRAGRetriever,
QueryableRAGRetriever,
)
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseIndexConfig,
@ -144,7 +148,7 @@ class SimpleEngine(RetrieverQueryEngine):
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
nodes = cls.get_obj_nodes(objs)
return cls._from_nodes(
nodes=nodes,
@ -201,7 +205,7 @@ class SimpleEngine(RetrieverQueryEngine):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
nodes = self.get_obj_nodes(objs)
self._save_nodes(nodes)
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
@ -210,6 +214,18 @@ class SimpleEngine(RetrieverQueryEngine):
self._persist(str(persist_dir), **kwargs)
def count(self) -> int:
    """Return the total count reported by the underlying retriever.

    `_ensure_retriever_queryable` is called first, then the retriever's
    `query_total_count` result is returned.
    """
    self._ensure_retriever_queryable()
    total = self._retriever.query_total_count()
    return total
@staticmethod
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
    """Converts a list of RAGObjects to a list of ObjectNodes.

    Args:
        objs: Objects to convert. `None` (the default) is treated as an empty list.

    Returns:
        One ObjectNode per object, using `obj.rag_key()` as the node text and
        `ObjectNode.get_obj_metadata(obj)` as the node metadata.
    """
    # The parameter defaults to None, which is not iterable; fall back to an empty list.
    return [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs or []]
@classmethod
def _from_nodes(
cls,
@ -258,6 +274,9 @@ class SimpleEngine(RetrieverQueryEngine):
def _ensure_retriever_persistable(self):
    """Ensure the retriever is of type `PersistableRAGRetriever` via `_ensure_retriever_of_type`."""
    self._ensure_retriever_of_type(PersistableRAGRetriever)
def _ensure_retriever_queryable(self):
    """Ensure the retriever is of type `QueryableRAGRetriever` via `_ensure_retriever_of_type`."""
    self._ensure_retriever_of_type(QueryableRAGRetriever)
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.

View file

@ -7,6 +7,7 @@ import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
@ -85,6 +86,12 @@ class RetrieverFactory(ConfigBasedFactory):
index = self._extract_index(config, **kwargs)
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
if index and not config.index:
config.index = index
if not config.index and config.create_index:
config.index = VectorStoreIndex(nodes, embed_model=MockEmbedding(embed_dim=1))
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:

View file

@ -45,3 +45,17 @@ class PersistableRAGRetriever(RAGRetriever):
@abstractmethod
def persist(self, persist_dir: str, **kwargs) -> None:
    """To support persist, must implement this func"""
class QueryableRAGRetriever(RAGRetriever):
    """Interface for retrievers that can report their total count.

    A class is recognised through `__subclasshook__`, which delegates to
    `check_methods` for the `query_total_count` method.
    """

    @classmethod
    def __subclasshook__(cls, C):
        # Only run the method check when the hook is invoked on this class itself;
        # for any other class, defer to the default subclass machinery.
        if cls is not QueryableRAGRetriever:
            return NotImplemented
        return check_methods(C, "query_total_count")

    @abstractmethod
    def query_total_count(self) -> int:
        """To support querying total count, must implement this func"""

View file

@ -47,3 +47,8 @@ class DynamicBM25Retriever(BM25Retriever):
"""Support persist."""
if self._index:
self._index.storage_context.persist(persist_dir)
def query_total_count(self) -> int:
    """Support query total count.

    Returns:
        int: The number of nodes held in `self._nodes`.
    """
    node_total = len(self._nodes)
    return node_total

View file

@ -2,6 +2,7 @@
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
from llama_index.vector_stores.chroma import ChromaVectorStore
class ChromaRetriever(VectorIndexRetriever):
@ -15,3 +16,10 @@ class ChromaRetriever(VectorIndexRetriever):
"""Support persist.
Chromadb automatically saves, so there is no need to implement."""
def query_total_count(self) -> int:
    """Support query total count.

    Returns:
        int: The result of `count()` on the vector store's `_collection`.
    """
    chroma_store: ChromaVectorStore = self._vector_store
    collection = chroma_store._collection
    return collection.count()

View file

@ -60,6 +60,11 @@ class FAISSRetrieverConfig(IndexRetrieverConfig):
class BM25RetrieverConfig(IndexRetrieverConfig):
"""Config for BM25-based retrievers."""
create_index: bool = Field(
default=False,
description="Indicates whether to create an index for the nodes. It is useful when you need to persist data while only using BM25.",
exclude=True,
)
_no_embedding: bool = PrivateAttr(default=True)

View file

@ -11,7 +11,7 @@ from metagpt.roles.di.role_zero import RoleZero
from metagpt.utils.common import tool2name
ARCHITECT_INSTRUCTION = """
Use WriteDesign tool to write a system design document if a system design is required; Use `write_trd_and_framework` tool to write a software framework if a software framework is required;
Use WriteDesign tool to write a system design document if a system design is required;
Note:
1. When you think, just analyze which tool you should use, and then provide your answer. And your output should contain firstly, secondly, ...

View file

@ -302,6 +302,10 @@ class RoleZero(Role):
# If the answer contains the substring '[Message] from A to B:', remove it.
pattern = r"\[Message\] from .+? to .+?:\s*"
answer = re.sub(pattern, "", answer, count=1)
if "command_name" in answer:
# an actual TASK intent misclassified as QUICK, correct it here, FIXME: a better way is to classify it correctly in the first place
answer = ""
intent_result = "TASK"
elif "SEARCH" in intent_result:
query = "\n".join(str(msg) for msg in memory)
answer = await SearchEnhancedQA().run(query)

View file

@ -69,14 +69,14 @@ class Browser(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
playwright: Optional[Playwright] = None
browser_instance: Optional[Browser_] = None
browser_ctx: Optional[BrowserContext] = None
page: Optional[Page] = None
playwright: Optional[Playwright] = Field(default=None, exclude=True)
browser_instance: Optional[Browser_] = Field(default=None, exclude=True)
browser_ctx: Optional[BrowserContext] = Field(default=None, exclude=True)
page: Optional[Page] = Field(default=None, exclude=True)
accessibility_tree: list = Field(default_factory=list)
headless: bool = True
headless: bool = Field(default=True)
proxy: Optional[dict] = Field(default_factory=get_proxy_from_env)
is_empty_page: bool = True
is_empty_page: bool = Field(default=True)
reporter: BrowserReporter = Field(default_factory=BrowserReporter)
async def start(self) -> None:

View file

@ -19,6 +19,7 @@ __all__ = [
"read_docx",
"Singleton",
"TOKEN_COSTS",
"new_transaction_id",
"count_message_tokens",
"count_string_tokens",
]

View file

@ -111,6 +111,12 @@ async def click_element(page: Page, backend_node_id: int):
resp = await get_bounding_rect(cdp_session, backend_node_id)
node_info = resp["result"]["value"]
x, y = await get_element_center(node_info)
# Move to the location of the element
await page.evaluate(f"window.scrollTo({x}- window.innerWidth/2,{y} - window.innerHeight/2);")
# Refresh the relative location of the element
resp = await get_bounding_rect(cdp_session, backend_node_id)
node_info = resp["result"]["value"]
x, y = await get_element_center(node_info)
await page.mouse.click(x, y)

View file

@ -26,6 +26,7 @@ import re
import sys
import time
import traceback
import uuid
from asyncio import iscoroutinefunction
from datetime import datetime
from functools import partial
@ -1089,6 +1090,19 @@ def tool2name(cls, methods: List[str], entry) -> Dict[str, Any]:
return mappings
def new_transaction_id(postfix_len=8) -> str:
    """Generate a transaction ID from the current local time and a random UUID.

    The ID is the current time formatted as `YYYYmmddHHMMSS`, followed by a literal
    `T`, followed by the leading hex digits of a freshly generated UUID4.

    Args:
        postfix_len (int): Number of UUID hex digits to append. Default is 8.
            The UUID hex string has 32 digits, so larger values yield at most 32.

    Returns:
        str: The timestamp, the `T` separator and the random postfix, concatenated.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%ST")
    random_postfix = uuid.uuid4().hex[:postfix_len]
    return f"{timestamp}{random_postfix}"
def log_time(method):
"""A time-consuming decorator for printing execution duration."""

View file

@ -8,6 +8,7 @@
import asyncio
import os
from pathlib import Path
from typing import List, Optional
from metagpt.config2 import Config
from metagpt.logs import logger
@ -15,16 +16,29 @@ from metagpt.utils.common import awrite, check_cmd_exists
async def mermaid_to_file(
engine, mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None
engine,
mermaid_code,
output_file_without_suffix,
width=2048,
height=2048,
config=None,
suffixes: Optional[List[str]] = None,
) -> int:
"""suffix: png/svg/pdf
"""Convert Mermaid code to various file formats.
:param mermaid_code: mermaid code
:param output_file_without_suffix: output filename
:param width:
:param height:
:return: 0 if succeed, -1 if failed
Args:
engine (str): The engine to use for conversion. Supported engines are "nodejs", "playwright", "pyppeteer", "ink", and "none".
mermaid_code (str): The Mermaid code to be converted.
output_file_without_suffix (str): The output file name without the suffix.
width (int, optional): The width of the output image. Defaults to 2048.
height (int, optional): The height of the output image. Defaults to 2048.
config (Optional[Config], optional): The configuration to use for the conversion. Defaults to None, which uses the default configuration.
suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"].
Returns:
int: 0 if the conversion is successful, -1 if the conversion fails.
"""
suffixes = suffixes or ["png"]
# Write the Mermaid code to a temporary file
config = config if config else Config.default()
dir_name = os.path.dirname(output_file_without_suffix)
@ -41,7 +55,7 @@ async def mermaid_to_file(
)
return -1
for suffix in ["pdf", "svg", "png"]:
for suffix in suffixes:
output_file = f"{output_file_without_suffix}.{suffix}"
# Call the `mmdc` command to convert the Mermaid code to a PNG
logger.info(f"Generating {output_file}..")
@ -75,15 +89,15 @@ async def mermaid_to_file(
if engine == "playwright":
from metagpt.utils.mmdc_playwright import mermaid_to_file
return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height)
return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height, suffixes=suffixes)
elif engine == "pyppeteer":
from metagpt.utils.mmdc_pyppeteer import mermaid_to_file
return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height)
return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height, suffixes=suffixes)
elif engine == "ink":
from metagpt.utils.mmdc_ink import mermaid_to_file
return await mermaid_to_file(mermaid_code, output_file_without_suffix)
return await mermaid_to_file(mermaid_code, output_file_without_suffix, suffixes=suffixes)
elif engine == "none":
return 0
else:

View file

@ -6,21 +6,29 @@
@File : mermaid.py
"""
import base64
from typing import List, Optional
from aiohttp import ClientError, ClientSession
from metagpt.logs import logger
async def mermaid_to_file(mermaid_code, output_file_without_suffix):
"""suffix: png/svg
:param mermaid_code: mermaid code
:param output_file_without_suffix: output filename without suffix
:return: 0 if succeed, -1 if failed
async def mermaid_to_file(mermaid_code, output_file_without_suffix, suffixes: Optional[List[str]] = None):
"""Convert Mermaid code to various file formats.
Args:
mermaid_code (str): The Mermaid code to be converted.
output_file_without_suffix (str): The output file name without the suffix.
width (int, optional): The width of the output image. Defaults to 2048.
height (int, optional): The height of the output image. Defaults to 2048.
suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"].
Returns:
int: 0 if the conversion is successful, -1 if the conversion fails.
"""
encoded_string = base64.b64encode(mermaid_code.encode()).decode()
for suffix in ["svg", "png"]:
suffixes = suffixes or ["png"]
for suffix in suffixes:
output_file = f"{output_file_without_suffix}.{suffix}"
path_type = "svg" if suffix == "svg" else "img"
url = f"https://mermaid.ink/{path_type}/{encoded_string}"

View file

@ -7,6 +7,7 @@
"""
import os
from typing import List, Optional
from urllib.parse import urljoin
from playwright.async_api import async_playwright
@ -14,20 +15,22 @@ from playwright.async_api import async_playwright
from metagpt.logs import logger
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
"""
Converts the given Mermaid code to various output formats and saves them to files.
async def mermaid_to_file(
mermaid_code, output_file_without_suffix, width=2048, height=2048, suffixes: Optional[List[str]] = None
) -> int:
"""Convert Mermaid code to various file formats.
Args:
mermaid_code (str): The Mermaid code to convert.
output_file_without_suffix (str): The output file name without the file extension.
width (int, optional): The width of the output image in pixels. Defaults to 2048.
height (int, optional): The height of the output image in pixels. Defaults to 2048.
mermaid_code (str): The Mermaid code to be converted.
output_file_without_suffix (str): The output file name without the suffix.
width (int, optional): The width of the output image. Defaults to 2048.
height (int, optional): The height of the output image. Defaults to 2048.
suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"].
Returns:
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
int: 0 if the conversion is successful, -1 if the conversion fails.
"""
suffixes = ["png", "svg", "pdf"]
suffixes = suffixes or ["png"]
__dirname = os.path.dirname(os.path.abspath(__file__))
async with async_playwright() as p:

View file

@ -6,6 +6,7 @@
@File : mmdc_pyppeteer.py
"""
import os
from typing import List, Optional
from urllib.parse import urljoin
from pyppeteer import launch
@ -14,21 +15,24 @@ from metagpt.config2 import Config
from metagpt.logs import logger
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None) -> int:
"""
Converts the given Mermaid code to various output formats and saves them to files.
async def mermaid_to_file(
mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None, suffixes: Optional[List[str]] = None
) -> int:
"""Convert Mermaid code to various file formats.
Args:
mermaid_code (str): The Mermaid code to convert.
output_file_without_suffix (str): The output file name without the file extension.
width (int, optional): The width of the output image in pixels. Defaults to 2048.
height (int, optional): The height of the output image in pixels. Defaults to 2048.
mermaid_code (str): The Mermaid code to be converted.
output_file_without_suffix (str): The output file name without the suffix.
width (int, optional): The width of the output image. Defaults to 2048.
height (int, optional): The height of the output image. Defaults to 2048.
config (Optional[Config], optional): The configuration to use for the conversion. Defaults to None, which uses the default configuration.
suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"].
Returns:
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
int: 0 if the conversion is successful, -1 if the conversion fails.
"""
config = config if config else Config.default()
suffixes = ["png", "svg", "pdf"]
suffixes = suffixes or ["png"]
__dirname = os.path.dirname(os.path.abspath(__file__))
if config.mermaid.pyppeteer_path:

View file

@ -1,26 +1,31 @@
import pytest
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
from metagpt.configs.exp_pool_config import (
ExperiencePoolConfig,
ExperiencePoolRetrievalType,
)
from metagpt.configs.llm_config import LLMConfig
from metagpt.exp_pool.manager import Experience, ExperienceManager
from metagpt.exp_pool.schema import QueryType
from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, QueryType
class TestExperienceManager:
@pytest.fixture
def mock_config(self):
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, enabled=True))
return Config(
llm=LLMConfig(),
exp_pool=ExperiencePoolConfig(
enable_write=True, enable_read=True, enabled=True, retrieval_type=ExperiencePoolRetrievalType.BM25
),
)
@pytest.fixture
def mock_storage(self, mocker):
engine = mocker.MagicMock()
engine.add_objs = mocker.MagicMock()
engine.aretrieve = mocker.AsyncMock(return_value=[])
engine._retriever = mocker.MagicMock()
engine._retriever._vector_store = mocker.MagicMock()
engine._retriever._vector_store._collection = mocker.MagicMock()
engine._retriever._vector_store._collection.count = mocker.MagicMock(return_value=10)
engine.count = mocker.MagicMock(return_value=10)
return engine
@pytest.fixture
@ -29,8 +34,33 @@ class TestExperienceManager:
manager._storage = mock_storage
return manager
def test_vector_store_property(self, exp_manager):
assert exp_manager.vector_store == exp_manager.storage._retriever._vector_store
def test_storage_property(self, exp_manager, mock_storage):
assert exp_manager.storage == mock_storage
def test_storage_property_initialization(self, mocker, mock_config):
mocker.patch.object(ExperienceManager, "_resolve_storage", return_value=mocker.MagicMock())
manager = ExperienceManager(config=mock_config)
assert manager._storage is None
_ = manager.storage
assert manager._storage is not None
def test_create_exp_write_disabled(self, exp_manager, mock_config):
mock_config.exp_pool.enable_write = False
exp = Experience(req="test", resp="response")
exp_manager.create_exp(exp)
exp_manager.storage.add_objs.assert_not_called()
def test_create_exp_write_enabled(self, exp_manager):
exp = Experience(req="test", resp="response")
exp_manager.create_exp(exp)
exp_manager.storage.add_objs.assert_called_once_with([exp])
exp_manager.storage.persist.assert_called_once_with(exp_manager.config.exp_pool.persist_path)
@pytest.mark.asyncio
async def test_query_exps_read_disabled(self, exp_manager, mock_config):
mock_config.exp_pool.enable_read = False
result = await exp_manager.query_exps("query")
assert result == []
@pytest.mark.asyncio
async def test_query_exps_with_exact_match(self, exp_manager, mocker):
@ -65,14 +95,50 @@ class TestExperienceManager:
def test_get_exps_count(self, exp_manager):
assert exp_manager.get_exps_count() == 10
# NOTE(review): this repeats, by name and body, the test_create_exp_write_disabled that appears earlier
# in this view. It is presumably the pre-move copy of the test shown by the diff rendering; confirm that
# only one definition survives, since a second method with the same name in a class replaces the first.
def test_create_exp_write_disabled(self, exp_manager, mock_config):
# Turn writing off in the pool config before exercising create_exp.
mock_config.exp_pool.enable_write = False
exp = Experience(req="test", resp="response")
exp_manager.create_exp(exp)
# Nothing may have been handed to the storage engine.
exp_manager.storage.add_objs.assert_not_called()
def test_resolve_storage_bm25(self, mocker, mock_config):
    """With BM25 retrieval configured, `_resolve_storage` calls `_create_bm25_storage` once."""
    mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.BM25
    # Replace the factory so no real BM25 engine is built.
    fake_storage = mocker.MagicMock()
    mocker.patch.object(ExperienceManager, "_create_bm25_storage", return_value=fake_storage)

    manager = ExperienceManager(config=mock_config)
    resolved = manager._resolve_storage()

    manager._create_bm25_storage.assert_called_once()
    assert resolved is not None
# NOTE(review): this repeats, by name and body, the test_query_exps_read_disabled that appears earlier
# in this view. It is presumably the pre-move copy of the test shown by the diff rendering; confirm that
# only one definition survives, since a second method with the same name in a class replaces the first.
@pytest.mark.asyncio
async def test_query_exps_read_disabled(self, exp_manager, mock_config):
# Turn reading off in the pool config before querying.
mock_config.exp_pool.enable_read = False
result = await exp_manager.query_exps("query")
# The query is expected to come back empty.
assert result == []
def test_resolve_storage_chroma(self, mocker, mock_config):
    """With Chroma retrieval configured, `_resolve_storage` calls `_create_chroma_storage` once."""
    mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.CHROMA
    # Replace the factory so no real Chroma engine is built.
    fake_storage = mocker.MagicMock()
    mocker.patch.object(ExperienceManager, "_create_chroma_storage", return_value=fake_storage)

    manager = ExperienceManager(config=mock_config)
    resolved = manager._resolve_storage()

    manager._create_chroma_storage.assert_called_once()
    assert resolved is not None
def test_create_bm25_storage(self, mocker, mock_config):
    """`_create_bm25_storage` returns a storage object when the engine pieces are stubbed out."""
    # Each of these targets is replaced by a stub that returns its own fresh MagicMock.
    magic_mock_targets = (
        "metagpt.rag.engines.SimpleEngine.from_objs",
        "metagpt.rag.engines.SimpleEngine.from_index",
        "metagpt.rag.engines.SimpleEngine._resolve_embed_model",
        "llama_index.core.VectorStoreIndex",
        "metagpt.rag.schema.BM25RetrieverConfig",
    )
    for target in magic_mock_targets:
        mocker.patch(target, return_value=mocker.MagicMock())

    # No object nodes, and every Path.exists() check answers False.
    mocker.patch("metagpt.rag.engines.SimpleEngine.get_obj_nodes", return_value=[])
    mocker.patch("pathlib.Path.exists", return_value=False)

    manager = ExperienceManager(config=mock_config)
    created = manager._create_bm25_storage()

    assert created is not None
def test_create_chroma_storage(self, mocker, mock_config):
    """`_create_chroma_storage` returns a storage object when `SimpleEngine.from_objs` is stubbed."""
    fake_engine = mocker.MagicMock()
    mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=fake_engine)

    manager = ExperienceManager(config=mock_config)
    created = manager._create_chroma_storage()

    assert created is not None
def test_get_ranker_configs_use_llm_ranker_true(self, mock_config):
    """Enabling `use_llm_ranker` yields exactly one ranker config using the default top-k."""
    mock_config.exp_pool.use_llm_ranker = True

    configs = ExperienceManager(config=mock_config)._get_ranker_configs()

    assert len(configs) == 1
    first_config = configs[0]
    assert first_config.top_n == DEFAULT_SIMILARITY_TOP_K
def test_get_ranker_configs_use_llm_ranker_false(self, mock_config):
    """Disabling `use_llm_ranker` yields no ranker configs."""
    mock_config.exp_pool.use_llm_ranker = False

    configs = ExperienceManager(config=mock_config)._get_ranker_configs()

    assert len(configs) == 0

View file

@ -75,7 +75,7 @@ class TestSimpleEngine:
)
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files, fs=None)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)

View file

@ -1,5 +1,6 @@
import pytest
from metagpt.config2 import Config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
@ -12,7 +13,10 @@ class TestRAGEmbeddingFactory:
@pytest.fixture
def mock_config(self, mocker):
# NOTE(review): as rendered, this return comes first, so the four statements after it cannot run.
# It looks like the superseded line of a change shown beside its replacement; confirm which is live.
return mocker.patch("metagpt.rag.factories.embedding.config")
# Work on a deep copy of the default Config, then patch Config.default so it hands back that same copy.
config = Config.default().model_copy(deep=True)
default = mocker.patch("metagpt.config2.Config.default")
default.return_value = config
return config
@staticmethod
def mock_openai_embedding(mocker):

View file

@ -8,28 +8,32 @@
import pytest
from metagpt.utils.common import check_cmd_exists
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.utils.common import check_cmd_exists, new_transaction_id
from metagpt.utils.mermaid import MMC1, mermaid_to_file
# Renders the MMC1 sample with mermaid_to_file for each parametrized engine, then checks that every
# expected output file exists and deletes it.
# NOTE(review): two parametrize decorators, two signatures and paired save_to / mermaid_to_file /
# loop-header lines appear below. Presumably these are the before and after versions of one change
# rendered together; confirm which set is current before relying on this block.
@pytest.mark.asyncio
@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer
async def test_mermaid(engine, context, mermaid_mocker):
@pytest.mark.parametrize(
("engine", "suffixes"), [("nodejs", None), ("nodejs", ["png", "svg", "pdf"]), ("ink", None)]
) # TODO: playwright and pyppeteer
async def test_mermaid(engine, suffixes, context, mermaid_mocker):
# nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli
# ink prerequisites: connected to internet
# playwright prerequisites: playwright install --with-deps chromium
assert check_cmd_exists("npm") == 0
save_to = context.git_repo.workdir / f"{engine}/1"
await mermaid_to_file(engine, MMC1, save_to)
# This variant builds the output path under DEFAULT_WORKSPACE_ROOT with a new_transaction_id() segment.
save_to = DEFAULT_WORKSPACE_ROOT / f"{new_transaction_id()}/{engine}/1"
await mermaid_to_file(engine, MMC1, save_to, suffixes=suffixes)
# ink does not support pdf
# When no suffixes are requested, only a .png output is expected.
exts = ["." + i for i in suffixes] if suffixes else [".png"]
if engine == "ink":
for ext in [".svg", ".png"]:
for ext in exts:
assert save_to.with_suffix(ext).exists()
# Remove the generated file once its existence has been checked.
save_to.with_suffix(ext).unlink(missing_ok=True)
else:
for ext in [".pdf", ".svg", ".png"]:
for ext in exts:
assert save_to.with_suffix(ext).exists()
save_to.with_suffix(ext).unlink(missing_ok=True)