diff --git a/metagpt/actions/search_enhanced_qa.py b/metagpt/actions/search_enhanced_qa.py index 23f21b2a8..1427f9b19 100644 --- a/metagpt/actions/search_enhanced_qa.py +++ b/metagpt/actions/search_enhanced_qa.py @@ -9,6 +9,7 @@ from pydantic import Field, PrivateAttr, model_validator from metagpt.actions import Action from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool from metagpt.tools.web_browser_engine import WebBrowserEngine from metagpt.utils.common import CodeParser from metagpt.utils.parse_html import WebPage @@ -57,8 +58,9 @@ Remember, don't blindly repeat the contexts verbatim. And here is the user quest """ +@register_tool(include_functions=["run"]) class SearchEnhancedQA(Action): - """Enhancing question-answering capabilities through search engine augmentation.""" + """Question answering and info searching through search engine.""" name: str = "SearchEnhancedQA" desc: str = "Integrating search engine results to anwser the question." diff --git a/metagpt/base/base_env.py b/metagpt/base/base_env.py index 7da9fd581..361b8b58f 100644 --- a/metagpt/base/base_env.py +++ b/metagpt/base/base_env.py @@ -2,14 +2,18 @@ # -*- coding: utf-8 -*- # @Desc : base environment +import typing from abc import abstractmethod from typing import Any, Optional from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams -from metagpt.schema import Message +from metagpt.base.base_serialization import BaseSerialization + +if typing.TYPE_CHECKING: + from metagpt.schema import Message -class BaseEnvironment: +class BaseEnvironment(BaseSerialization): """Base environment""" @abstractmethod @@ -30,7 +34,7 @@ class BaseEnvironment: """Implement this to feed a action and then get new observation from the env""" @abstractmethod - def publish_message(self, message: Message, peekable: bool = True) -> bool: + def publish_message(self, message: "Message", peekable: bool = True) -> bool: """Distribute the message to the recipients.""" @abstractmethod diff --git a/metagpt/base/base_role.py b/metagpt/base/base_role.py index 2f6c9f963..1f7f00fa2 100644 --- a/metagpt/base/base_role.py +++ b/metagpt/base/base_role.py @@ -1,10 +1,10 @@ from abc import abstractmethod from typing import Optional, Union -from metagpt.schema import Message +from metagpt.base.base_serialization import BaseSerialization -class BaseRole: +class BaseRole(BaseSerialization): """Abstract base class for all roles.""" name: str @@ -24,13 +24,13 @@ class BaseRole: raise NotImplementedError @abstractmethod - async def react(self) -> Message: + async def react(self) -> "Message": """Entry to one of three strategies by which Role reacts to the observed Message.""" @abstractmethod - async def run(self, with_message: Optional[Union[str, Message, list[str]]] = None) -> Optional[Message]: + async def run(self, with_message: Optional[Union[str, "Message", list[str]]] = None) -> Optional["Message"]: """Observe, and think and act based on the results of the observation.""" @abstractmethod - def get_memories(self, k: int = 0) -> list[Message]: + def get_memories(self, k: int = 0) -> list["Message"]: """Return the most recent k memories of this role.""" diff --git a/metagpt/base/base_serialization.py b/metagpt/base/base_serialization.py new file mode 100644 index 000000000..8aff7f39e --- /dev/null +++ b/metagpt/base/base_serialization.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, model_serializer, model_validator + + +class BaseSerialization(BaseModel, extra="forbid"): + """ + PolyMorphic subclasses Serialization / Deserialization Mixin + - First of all, we need to know that pydantic is not designed for polymorphism. + - If Engineer is subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need + to add `class name` to Engineer. So we need Engineer inherit SerializationMixin. + + More details: + - https://docs.pydantic.dev/latest/concepts/serialization/ + - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__` + """ + + __is_polymorphic_base = False + __subclasses_map__ = {} + + @model_serializer(mode="wrap") + def __serialize_with_class_type__(self, default_serializer) -> Any: + # default serializer, then append the `__module_class_name` field and return + ret = default_serializer(self) + ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + return ret + + @model_validator(mode="wrap") + @classmethod + def __convert_to_real_type__(cls, value: Any, handler): + if isinstance(value, dict) is False: + return handler(value) + + # it is a dict so make sure to remove the __module_class_name + # because we don't allow extra keywords but want to ensure + # e.g Cat.model_validate(cat.model_dump()) works + class_full_name = value.pop("__module_class_name", None) + + # if it's not the polymorphic base we construct via default handler + if not cls.__is_polymorphic_base: + if class_full_name is None: + return handler(value) + elif str(cls) == f"": + return handler(value) + else: + # f"Trying to instantiate {class_full_name} but this is not the polymorphic base class") + pass + + # otherwise we lookup the correct polymorphic type and construct that + # instead + if class_full_name is None: + raise ValueError("Missing __module_class_name field") + + class_type = cls.__subclasses_map__.get(class_full_name, None) + + if class_type is None: + # TODO could try dynamic import + raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!") + + return class_type(**value) + + def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs): + cls.__is_polymorphic_base = is_polymorphic_base + cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls + super().__init_subclass__(**kwargs) diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py index 20de47999..f9b41b9dc 100644 --- a/metagpt/configs/embedding_config.py +++ b/metagpt/configs/embedding_config.py @@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel): --------- api_type: "openai" api_key: "YOU_API_KEY" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "azure" api_key: "YOU_API_KEY" base_url: "YOU_BASE_URL" api_version: "YOU_API_VERSION" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "gemini" api_key: "YOU_API_KEY" @@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel): api_type: "ollama" base_url: "YOU_BASE_URL" model: "YOU_MODEL" + dimensions: "YOUR_MODEL_DIMENSIONS" """ api_type: Optional[EmbeddingType] = None @@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel): model: Optional[str] = None embed_batch_size: Optional[int] = None + dimensions: Optional[int] = None # output dimension of embedding model @field_validator("api_type", mode="before") @classmethod diff --git a/metagpt/prompts/di/data_analyst.py b/metagpt/prompts/di/data_analyst.py index 8e5b888d3..9f943b187 100644 --- a/metagpt/prompts/di/data_analyst.py +++ b/metagpt/prompts/di/data_analyst.py @@ -1,12 +1,12 @@ from metagpt.strategy.task_type import TaskType EXTRA_INSTRUCTION = """ -6. Carefully choose to use or not use the browser tool to assist you in web tasks. - - When no click action is required, no need to use the Browser tool to navigate to the webpage before scraping. - - Write code to view the HTML content rather than using the Browser tool. - - Make sure the command_name are certainly in Available Commands when you use the Browser tool. - - For information searching requirement, you should use the Browser tool instead of web scraping. - - When no link is provided, you should use the Browser tool to search for the information. +6. Carefully consider how you handle web tasks: + - Use SearchEnhancedQA for general information searching, i.e. querying search engines, such as googling news, weather, wiki, etc. Usually, no link is provided. + - Use Browser for reading, navigating, or in-domain searching within a specific web, such as reading a blog, searching products from a given e-commerce web link, or interacting with a web app. + - Use DataAnalyst.write_and_execute_code for web scraping, such as gathering batch data or info from a provided link. + - Write code to view the HTML content rather than using the Browser tool. + - Make sure the command_name are certainly in Available Commands when you use the Browser tool. 7. When you are making plan. It is highly recommend to plan and append all the tasks in first response once time, except for 7.1. 7.1. When the requirement is inquiring about a pdf, docx, md, or txt document, read the document first through either Editor.read WITHOUT a plan. After reading the document, use RoleZero.reply_to_human if the requirement can be answered straightaway, otherwise, make a plan if further calculation is needed. 8. Don't finish_current_task multiple times for the same task. diff --git a/metagpt/prompts/di/role_zero.py b/metagpt/prompts/di/role_zero.py index 956f26834..3029735ba 100644 --- a/metagpt/prompts/di/role_zero.py +++ b/metagpt/prompts/di/role_zero.py @@ -79,7 +79,7 @@ Output should adhere to the following format. ```json [ {{ - "command_name": str, + "command_name": "ClassName.method_name" or "function_name", "args": {{"arg_name": arg_value, ...}} }}, ... diff --git a/metagpt/prompts/di/team_leader.py b/metagpt/prompts/di/team_leader.py index 3ba9a8b0d..e5c119dc8 100644 --- a/metagpt/prompts/di/team_leader.py +++ b/metagpt/prompts/di/team_leader.py @@ -14,8 +14,15 @@ Pay close attention to new user message, review the conversation history, use Ro Pay close attention to messages from team members. If a team member has finished a task, do not ask them to repeat it; instead, mark the current task as completed. Note: 1. If the requirement is a pure DATA-RELATED requirement, such as web browsing, web scraping, web searching, web imitation, data science, data analysis, machine learning, deep learning, text-to-image etc. DON'T decompose it, assign a single task with the original user requirement as instruction directly to Data Analyst. -2. If the requirement is developing a software, game, app, or website, excluding the above data-related tasks, you should decompose the requirement into multiple tasks and assign them to different team members based on their expertise. The software default development process has four steps: creating a Product Requirement Document (PRD) by the Product Manager -> writing a System Design by the Architect -> creating tasks by the Project Manager -> and coding by the Engineer. You may choose to execute any of these steps. When publishing message to Product Manager, you should directly copy the full original user requirement. +2. If the requirement is developing a software, game, app, or website, excluding the above data-related tasks, you should decompose the requirement into multiple tasks and assign them to different team members based on their expertise. The standard software development process has four steps: creating a Product Requirement Document (PRD) by the Product Manager -> writing a System Design by the Architect -> creating tasks by the Project Manager -> and coding by the Engineer. You may choose to execute any of these steps. When publishing message to Product Manager, you should directly copy the full original user requirement. 2.1. If the requirement contains both DATA-RELATED part mentioned in 1 and software development part mentioned in 2, you should decompose the software development part and assign them to different team members based on their expertise, and assign the DATA-RELATED part to Data Analyst David directly. +2.2. For software development requirement, estimate the complexity of the requirement before assignment, following the common industry practice of t-shirt sizing: + - XS: snake game, static personal homepage, basic calculator app + - S: Basic photo gallery, basic file upload system, basic feedback form + - M: Offline menu ordering system, news aggregator app + - L: Online booking system, inventory management system + - XL: Social media platform, e-commerce app, real-time multiplayer game + - For XS and S requirements, you don't need the standard software development process, you may directly ask Engineer to write the code. Otherwise, estimate if any part of the standard software development process may contribute to a better final code. If so, assign team members accordingly. 3.1 If the task involves code review (CR) or code checking, you should assign it to Engineer. 3.2. If the requirement is to fix a bug or issue, you should assign it to Issue Solver. However, if the code is written by Engineer, Engineer must maintain the code. 4. If the requirement is a common-sense, logical, or math problem, you should respond directly without assigning any task to team members. diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 75d8bfe00..f9111ffe0 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -27,7 +27,6 @@ from metagpt.configs.llm_config import LLMConfig from metagpt.const import IMAGES, LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT from metagpt.logs import logger from metagpt.provider.constant import MULTI_MODAL_MODELS -from metagpt.schema import Message from metagpt.utils.common import log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs from metagpt.utils.token_counter import TOKEN_MAX @@ -80,7 +79,7 @@ class BaseLLM(ABC): def support_image_input(self) -> bool: return any([m in self.config.model for m in MULTI_MODAL_MODELS]) - def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]: """convert messages to list[dict].""" from metagpt.schema import Message @@ -173,7 +172,9 @@ class BaseLLM(ABC): context.append(self._assistant_msg(rsp_text)) return self._extract_assistant_rsp(context) - async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs) -> dict: + async def aask_code( + self, messages: Union[str, "Message", list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs + ) -> dict: raise NotImplementedError @abstractmethod diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index e4b3a3f17..5c1b92503 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -22,7 +22,6 @@ from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider -from metagpt.schema import Message class GeminiGenerativeModel(GenerativeModel): @@ -73,7 +72,7 @@ class GeminiLLM(BaseLLM): def _system_msg(self, msg: str) -> dict[str, str]: return {"role": "user", "parts": [msg]} - def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]: """convert messages to list[dict].""" from metagpt.schema import Message diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index e48decdab..8d78fcad7 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -2,7 +2,8 @@ import json import os -from typing import Any, Optional, Union +from pathlib import Path +from typing import Any, List, Optional, Set, Union import fsspec from llama_index.core import SimpleDirectoryReader @@ -78,6 +79,7 @@ class SimpleEngine(RetrieverQueryEngine): callback_manager=callback_manager, ) self._transformations = transformations or self._default_transformations() + self._filenames = set() @classmethod def from_docs( @@ -192,11 +194,11 @@ class SimpleEngine(RetrieverQueryEngine): self._try_reconstruct_obj(nodes) return nodes - def add_docs(self, input_files: list[str]): + def add_docs(self, input_files: List[Union[str, Path]]): """Add docs to retriever. retriever must has add_nodes func.""" self._ensure_retriever_modifiable() - documents = SimpleDirectoryReader(input_files=input_files).load_data() + documents = SimpleDirectoryReader(input_files=[str(i) for i in input_files]).load_data() self._fix_document_metadata(documents) nodes = run_transformations(documents, transformations=self._transformations) @@ -227,6 +229,24 @@ class SimpleEngine(RetrieverQueryEngine): return self.retriever.clear(**kwargs) + def delete_docs(self, input_files: List[Union[str, Path]]): + """Delete documents from the index and document store. + + Args: + input_files (List[Union[str, Path]]): A list of file paths or file names to be deleted. + + Raises: + NotImplementedError: If the method is not implemented. + """ + exists_filenames = set() + filenames = {str(i) for i in input_files} + for doc_id, info in self.retriever._index.ref_doc_info.items(): + if info.metadata.get("file_path") in filenames: + exists_filenames.add(doc_id) + + for doc_id in exists_filenames: + self.retriever._index.delete_ref_doc(doc_id, delete_from_docstore=True) + @staticmethod def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]: """Converts a list of RAGObjects to a list of ObjectNodes.""" @@ -333,3 +353,7 @@ class SimpleEngine(RetrieverQueryEngine): @staticmethod def _default_transformations(): return [SentenceSplitter()] + + @property + def filenames(self) -> Set[str]: + return self._filenames diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index d647883bd..19b8b36f6 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -5,9 +5,6 @@ from typing import Any, Optional from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding -from llama_index.embeddings.gemini import GeminiEmbedding -from llama_index.embeddings.ollama import OllamaEmbedding -from llama_index.embeddings.openai import OpenAIEmbedding from metagpt.config2 import Config from metagpt.configs.embedding_config import EmbeddingType @@ -49,7 +46,9 @@ class RAGEmbeddingFactory(GenericFactory): raise TypeError("To use RAG, please set your embedding in config2.yaml.") - def _create_openai(self) -> OpenAIEmbedding: + def _create_openai(self) -> "OpenAIEmbedding": + from llama_index.embeddings.openai import OpenAIEmbedding + params = dict( api_key=self.config.embedding.api_key or self.config.llm.api_key, api_base=self.config.embedding.base_url or self.config.llm.base_url, @@ -70,7 +69,9 @@ class RAGEmbeddingFactory(GenericFactory): return AzureOpenAIEmbedding(**params) - def _create_gemini(self) -> GeminiEmbedding: + def _create_gemini(self) -> "GeminiEmbedding": + from llama_index.embeddings.gemini import GeminiEmbedding + params = dict( api_key=self.config.embedding.api_key, api_base=self.config.embedding.base_url, @@ -80,7 +81,9 @@ class RAGEmbeddingFactory(GenericFactory): return GeminiEmbedding(**params) - def _create_ollama(self) -> OllamaEmbedding: + def _create_ollama(self) -> "OllamaEmbedding": + from llama_index.embeddings.ollama import OllamaEmbedding + params = dict( base_url=self.config.embedding.base_url, ) diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 59f6db4d9..bd252771a 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -13,7 +13,6 @@ from llama_index.core.llms.callbacks import llm_completion_callback from pydantic import Field from metagpt.config2 import Config -from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.token_counter import TOKEN_MAX @@ -79,4 +78,6 @@ class RAGLLM(CustomLLM): def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM: """Get llm that can be used by LlamaIndex.""" + from metagpt.llm import LLM + return RAGLLM(model_infer=model_infer or LLM()) diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py index 329b3c45d..f9bead1ac 100644 --- a/metagpt/roles/di/data_analyst.py +++ b/metagpt/roles/di/data_analyst.py @@ -30,8 +30,8 @@ class DataAnalyst(RoleZero): instruction: str = ROLE_INSTRUCTION + EXTRA_INSTRUCTION task_type_desc: str = TASK_TYPE_DESC - tools: list[str] = ["Plan", "DataAnalyst", "RoleZero", "Browser", "Editor:write,read"] - custom_tools: list[str] = ["web scraping", "Terminal"] + tools: list[str] = ["Plan", "DataAnalyst", "RoleZero", "Browser", "Editor:write,read", "SearchEnhancedQA"] + custom_tools: list[str] = ["web scraping", "Terminal", "Editor:write,read"] custom_tool_recommender: ToolRecommender = None experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = KeywordExpRetriever() diff --git a/metagpt/roles/di/engineer2.py b/metagpt/roles/di/engineer2.py index bee5aa04d..92ecb633d 100644 --- a/metagpt/roles/di/engineer2.py +++ b/metagpt/roles/di/engineer2.py @@ -40,9 +40,10 @@ class Engineer2(RoleZero): "Plan", "Editor", "RoleZero", - "Terminal", + "Terminal:run_command", "Browser:goto,scroll", "git_create_pull", + "SearchEnhancedQA", "Engineer2", ] # SWE Agent parameter diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index d28f27138..0e8d005e7 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -31,7 +31,6 @@ from metagpt.prompts.di.role_zero import ( ROLE_INSTRUCTION, SUMMARY_PROMPT, SYSTEM_PROMPT, - THOUGHT_GUIDANCE, ) from metagpt.roles import Role from metagpt.schema import AIMessage, Message, UserMessage @@ -62,7 +61,6 @@ class RoleZero(Role): system_prompt: str = SYSTEM_PROMPT # Use None to conform to the default value at llm.aask cmd_prompt: str = CMD_PROMPT cmd_prompt_current_state: str = "" - thought_guidance: str = THOUGHT_GUIDANCE instruction: str = ROLE_INSTRUCTION task_type_desc: Optional[str] = None @@ -90,7 +88,7 @@ class RoleZero(Role): # Others command_rsp: str = "" # the raw string containing the commands commands: list[dict] = [] # commands to be executed - memory_k: int = 20 # number of memories (messages) to use as historical context + memory_k: int = 100 # number of memories (messages) to use as historical context use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements use_summary: bool = True # whether to summarize at the end @@ -120,6 +118,7 @@ class RoleZero(Role): "Plan.replace_task": self.planner.plan.replace_task, "RoleZero.ask_human": self.ask_human, "RoleZero.reply_to_human": self.reply_to_human, + "SearchEnhancedQA.run": SearchEnhancedQA().run, } self.tool_execution_map.update( { diff --git a/metagpt/schema.py b/metagpt/schema.py index 201ff4357..ce64d130a 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -34,10 +34,9 @@ from pydantic import ( create_model, field_serializer, field_validator, - model_serializer, - model_validator, ) +from metagpt.base.base_serialization import BaseSerialization from metagpt.const import ( AGENT, MESSAGE_ROUTE_CAUSE_BY, @@ -69,67 +68,7 @@ from metagpt.utils.serialize import ( ) -class SerializationMixin(BaseModel, extra="forbid"): - """ - PolyMorphic subclasses Serialization / Deserialization Mixin - - First of all, we need to know that pydantic is not designed for polymorphism. - - If Engineer is subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need - to add `class name` to Engineer. So we need Engineer inherit SerializationMixin. - - More details: - - https://docs.pydantic.dev/latest/concepts/serialization/ - - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__` - """ - - __is_polymorphic_base = False - __subclasses_map__ = {} - - @model_serializer(mode="wrap") - def __serialize_with_class_type__(self, default_serializer) -> Any: - # default serializer, then append the `__module_class_name` field and return - ret = default_serializer(self) - ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}" - return ret - - @model_validator(mode="wrap") - @classmethod - def __convert_to_real_type__(cls, value: Any, handler): - if isinstance(value, dict) is False: - return handler(value) - - # it is a dict so make sure to remove the __module_class_name - # because we don't allow extra keywords but want to ensure - # e.g Cat.model_validate(cat.model_dump()) works - class_full_name = value.pop("__module_class_name", None) - - # if it's not the polymorphic base we construct via default handler - if not cls.__is_polymorphic_base: - if class_full_name is None: - return handler(value) - elif str(cls) == f"": - return handler(value) - else: - # f"Trying to instantiate {class_full_name} but this is not the polymorphic base class") - pass - - # otherwise we lookup the correct polymorphic type and construct that - # instead - if class_full_name is None: - raise ValueError("Missing __module_class_name field") - - class_type = cls.__subclasses_map__.get(class_full_name, None) - - if class_type is None: - # TODO could try dynamic import - raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!") - - return class_type(**value) - - def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs): - cls.__is_polymorphic_base = is_polymorphic_base - cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls - super().__init_subclass__(**kwargs) - +class SerializationMixin(BaseSerialization): @handle_exception def serialize(self, file_path: str = None) -> str: """Serializes the current instance to a JSON file. diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 12af8611f..e7fc5f0a1 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -724,8 +724,7 @@ class Editor(BaseModel): return ret_str def edit_file_by_replace(self, file_name: str, to_replace: str, new_content: str) -> str: - """ - Edit a file. This will search for `to_replace` in the given file and replace it with `new_content`. + """Edit a file. This will search for `to_replace` in the given file and replace it with `new_content`. Every *to_replace* must *EXACTLY MATCH* the existing source code, character for character, including all comments, docstrings, etc. diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py new file mode 100644 index 000000000..fadc11522 --- /dev/null +++ b/metagpt/tools/libs/index_repo.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import json +from pathlib import Path +from typing import Dict, List, Optional, Set, Union + +import tiktoken +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.schema import NodeWithScore +from pydantic import BaseModel, Field, model_validator + +from metagpt.config2 import Config +from metagpt.logs import logger +from metagpt.rag.engines import SimpleEngine +from metagpt.rag.factories.embedding import RAGEmbeddingFactory +from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig +from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files +from metagpt.utils.repo_to_markdown import is_text_file + + +class TextScore(BaseModel): + filename: str + text: str + score: Optional[float] = None + + +class IndexRepo(BaseModel): + persist_path: str # The persist path of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/ + root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. + fingerprint_filename: str = "fingerprint.json" + model: Optional[str] = None + min_token_count: int = 10000 + max_token_count: int = 100000000 + recall_count: int = 5 + embedding: Optional[BaseEmbedding] = Field(default=None, exclude=True) + fingerprints: Dict[str, str] = Field(default_factory=dict) + + @model_validator(mode="after") + def _update_fingerprints(self) -> "IndexRepo": + """Load fingerprints from the fingerprint file if not already loaded. + + Returns: + IndexRepo: The updated IndexRepo instance. + """ + if not self.fingerprints: + filename = Path(self.persist_path) / self.fingerprint_filename + if not filename.exists(): + return self + with open(str(filename), "r") as reader: + self.fingerprints = json.load(reader) + return self + + async def search( + self, query: str, filenames: Optional[List[Path]] = None + ) -> Optional[List[Union[NodeWithScore, TextScore]]]: + """Search for documents related to the given query. + + Args: + query (str): The search query. + filenames (Optional[List[Path]]): A list of filenames to filter the search. + + Returns: + Optional[List[Union[NodeWithScore, TextScore]]]: A list of search results containing NodeWithScore or TextScore. + """ + encoding = tiktoken.get_encoding("cl100k_base") + result: List[Union[NodeWithScore, TextScore]] = [] + filenames, _ = await self._filter(filenames) + filter_filenames = set() + for i in filenames: + content = await aread(filename=i) + token_count = len(encoding.encode(content)) + if not self._is_buildable(token_count): + result.append(TextScore(filename=str(i), text=content)) + continue + file_fingerprint = generate_fingerprint(content) + if self.fingerprints.get(str(i)) != file_fingerprint: + logger.error(f'file: "{i}" changed but not indexed') + continue + filter_filenames.add(str(i)) + nodes = await self._search(query=query, filters=filter_filenames) + return result + nodes + + async def merge( + self, query: str, indices_list: List[List[Union[NodeWithScore, TextScore]]] + ) -> List[Union[NodeWithScore, TextScore]]: + """Merge results from multiple indices based on the query. + + Args: + query (str): The search query. + indices_list (List[List[Union[NodeWithScore, TextScore]]]): A list of result lists from different indices. + + Returns: + List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity. + """ + if not self.embedding: + config = Config.default() + if self.model: + config.embedding.model = self.model + factory = RAGEmbeddingFactory(config) + self.embedding = factory.get_rag_embedding() + + scores = [] + query_embedding = await self.embedding.aget_text_embedding(query) + flat_nodes = [node for indices in indices_list for node in indices] + for i in flat_nodes: + text_embedding = await self.embedding.aget_text_embedding(i.text) + similarity = self.embedding.similarity(query_embedding, text_embedding) + scores.append((similarity, i)) + scores.sort(key=lambda x: x[0], reverse=True) + return [i[1] for i in scores][: self.recall_count] + + async def add(self, paths: List[Path]): + """Add new documents to the index. + + Args: + paths (List[Path]): A list of paths to the documents to be added. + """ + encoding = tiktoken.get_encoding("cl100k_base") + filenames, _ = await self._filter(paths) + filter_filenames = [] + delete_filenames = [] + for i in filenames: + content = await aread(filename=i) + if not self._is_fingerprint_changed(filename=i, content=content): + continue + token_count = len(encoding.encode(content)) + if self._is_buildable(token_count): + filter_filenames.append(i) + logger.debug(f"{i} is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}") + else: + delete_filenames.append(i) + logger.debug(f"{i} not is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}") + await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames) + + async def _add_batch(self, filenames: List[Union[str, Path]], delete_filenames: List[Union[str, Path]]): + """Add and remove documents in a batch operation. + + Args: + filenames (List[Union[str, Path]]): List of filenames to add. + delete_filenames (List[Union[str, Path]]): List of filenames to delete. + """ + if not filenames: + return + logger.info(f"update index repo, add {filenames}, remove {delete_filenames}") + engine = None + if Path(self.persist_path).exists(): + logger.debug(f"load index from {self.persist_path}") + engine = SimpleEngine.from_index( + index_config=FAISSIndexConfig(persist_path=self.persist_path), + retriever_configs=[FAISSRetrieverConfig()], + ) + try: + engine.delete_docs(filenames + delete_filenames) + logger.debug(f"delete docs {filenames + delete_filenames}") + engine.add_docs(input_files=filenames) + logger.debug(f"add docs {filenames}") + except NotImplementedError as e: + logger.debug(f"{e}") + filenames = list(set([str(i) for i in filenames] + list(self.fingerprints.keys()))) + engine = None + logger.info(f"{e}. Rebuild all.") + if not engine: + engine = SimpleEngine.from_docs( + input_files=[str(i) for i in filenames], + retriever_configs=[FAISSRetrieverConfig()], + ranker_configs=[LLMRankerConfig()], + ) + logger.debug(f"add docs {filenames}") + engine.persist(persist_dir=self.persist_path) + for i in filenames: + content = await aread(i) + fp = generate_fingerprint(content) + self.fingerprints[str(i)] = fp + await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) + + def __str__(self): + """Return a string representation of the IndexRepo. + + Returns: + str: The filename of the index repository. + """ + return f"{self.persist_path}" + + def _is_buildable(self, token_count: int) -> bool: + """Check if the token count is within the buildable range. + + Args: + token_count (int): The number of tokens in the content. + + Returns: + bool: True if buildable, False otherwise. + """ + if token_count < self.min_token_count or token_count > self.max_token_count: + return False + return True + + async def _filter(self, filenames: Optional[List[Union[str, Path]]] = None) -> (List[Path], List[Path]): + """Filter the provided filenames to only include valid text files. + + Args: + filenames (Optional[List[Union[str, Path]]]): List of filenames to filter. + + Returns: + Tuple[List[Path], List[Path]]: A tuple containing a list of valid pathnames and a list of excluded paths. + """ + root_path = Path(self.root_path).absolute() + if not filenames: + filenames = [root_path] + pathnames = [] + excludes = [] + for i in filenames: + path = Path(i).absolute() + if not path.is_relative_to(root_path): + excludes.append(path) + logger.debug(f"{path} not is_relative_to {root_path})") + continue + if not path.is_dir(): + is_text, _ = await is_text_file(path) + if is_text: + pathnames.append(path) + continue + subfiles = list_files(path) + for j in subfiles: + is_text, _ = await is_text_file(j) + if is_text: + pathnames.append(j) + + logger.debug(f"{pathnames}, excludes:{excludes})") + return pathnames, excludes + + async def _search(self, query: str, filters: Set[str]) -> List[NodeWithScore]: + """Perform a search for the given query using the index. + + Args: + query (str): The search query. + filters (Set[str]): A set of filenames to filter the search results. + + Returns: + List[NodeWithScore]: A list of nodes with scores matching the query. + """ + if not Path(self.persist_path).exists(): + return [] + engine = SimpleEngine.from_index( + index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()] + ) + rsp = await engine.aretrieve(query) + return [i for i in rsp if i.metadata.get("file_path") in filters] + + def _is_fingerprint_changed(self, filename: Union[str, Path], content: str) -> bool: + """Check if the fingerprint of the given document content has changed. + + Args: + filename (Union[str, Path]): The filename of the document. + content (str): The content of the document. + + Returns: + bool: True if the fingerprint has changed, False otherwise. + """ + old_fp = self.fingerprints.get(str(filename)) + if not old_fp: + return True + fp = generate_fingerprint(content) + return old_fp != fp diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 42a872c76..90f13da23 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -16,6 +16,7 @@ import base64 import contextlib import csv import functools +import hashlib import importlib import inspect import json @@ -889,7 +890,7 @@ async def get_mime_type(filename: str | Path, force_read: bool = False) -> str: } try: - stdout, stderr, _ = await shell_execute(f"file --mime-type {str(filename)}") + stdout, stderr, _ = await shell_execute(f"file --mime-type '{str(filename)}'") if stderr: logger.debug(f"file:{filename}, error:{stderr}") return guess_mime_type @@ -1175,3 +1176,23 @@ def rectify_pathname(path: Union[str, Path], default_filename: str) -> Path: else: output_pathname.parent.mkdir(parents=True, exist_ok=True) return output_pathname + + +def generate_fingerprint(text: str) -> str: + """ + Generate a fingerprint for the given text + + Args: + text (str): The text for which the fingerprint needs to be generated + + Returns: + str: The fingerprint value of the text + """ + text_bytes = text.encode("utf-8") + + # calculate SHA-256 hash + sha256 = hashlib.sha256() + sha256.update(text_bytes) + fingerprint = sha256.hexdigest() + + return fingerprint diff --git a/tests/data/embedding/2.answer.md b/tests/data/embedding/2.answer.md new file mode 100644 index 000000000..3807f03c1 --- /dev/null +++ b/tests/data/embedding/2.answer.md @@ -0,0 +1,2 @@ +检索结果 +法务查询者可根据国际小超人钉钉小程序UI上的滚筒切换业务线 这张图片展示了一个移动应用的界面,界面标题为“法律意见详情”。用户可以根据具体情况切换业务线。界面中有多个字段,包括“国家名称”、“国家情况描述”、“业务线”、“产品法规分析”和“签约主体”。第一张截图显示了详细的法律情报信息,包含区域名称、区域情况描述、业务线和产品法规概述等字段。第二张截图显示了“法律意见详情”界面,其中列出了国家名称、国家情况描述、业务线、产品法规分析和签约主体。第三张截图与第二张相似,但显示了选项的可选择状态。最下方有“取消”和“确定”的按钮。 法务查询者从国家详情中的业务线名列表中选出要查看的业务线。 \ No newline at end of file diff --git a/tests/data/embedding/2.knowledge.md b/tests/data/embedding/2.knowledge.md new file mode 100644 index 000000000..615614098 --- /dev/null +++ b/tests/data/embedding/2.knowledge.md @@ -0,0 +1,25 @@ +## Textual User Requirements + +### 3.2. 首页 + +首页有两个分区,上面部分是法律意见检索栏。 + +法务查询者第一次进入国际小超人钉钉小程序展示引导页,以后进入不再展示,点击「我知道了」引导页消失。 + +#### 首页 +![首页](1.png) +这是一个名为“法务小超人”的移动应用程序的界面截图。界面顶部显示了应用名称和一个可切换语言的按钮“English”。在界面中间部分,有一个标题“法律意见查询”,以及一个搜索框,提示输入国家名称以查询法律意见。下方显示已收录法律意见8394篇。界面下半部分是“法务 Q&A”部分,列出了一些法律相关的选项,例如“国际法务接入口人”、“国内法务接入口人”、“国际法律协议合同办理指引”和“国内法律协议合同办理指引”。界面底部有三个导航按钮,分别是“首页”、“模板”和“我的”。 + +#### 按国家名维度搜索 +法务查询者在国际小超人钉钉小程序的搜索框中进行检索时采用typeahead,只能下拉选择法务中台中有的国家名称。 +![按国家名维度搜索](2.png) +在这张图像中,用户正在一个名为“法律意见查询”的应用中进行国家名称的搜索。用户在搜索框中输入国家名称时,系统会提供下拉建议。这些建议基于 typeahead 功能,从法务中台中筛选出匹配的国家名称供用户选择。目前,搜索结果包含了“中国”和“菲律宾”两个具体的国家名称,其它显示为“国家名”。用户可以通过下拉菜单快速选择所需的国家名称。 + +#### 检索结果 +法务查询者可根据国际小超人钉钉小程序UI上的滚筒切换业务线 +![检索结果](3.png) +这张图片展示了一个移动应用的界面,界面标题为“法律意见详情”。用户可以根据具体情况切换业务线。界面中有多个字段,包括“国家名称”、“国家情况描述”、“业务线”、“产品法规分析”和“签约主体”。第一张截图显示了详细的法律情报信息,包含区域名称、区域情况描述、业务线和产品法规概述等字段。第二张截图显示了“法律意见详情”界面,其中列出了国家名称、国家情况描述、业务线、产品法规分析和签约主体。第三张截图与第二张相似,但显示了选项的可选择状态。最下方有“取消”和“确定”的按钮。 +法务查询者从国家详情中的业务线名列表中选出要查看的业务线。 + +#### 查看法律意见详情 +国际小超人钉钉小程序用国家代码和业务代码做参数,查询法律意见详情,然后将法律意见详情展示给法务查询者。 \ No newline at end of file diff --git a/tests/data/embedding/2.query.md b/tests/data/embedding/2.query.md new file mode 100644 index 000000000..ba470b8bd --- /dev/null +++ b/tests/data/embedding/2.query.md @@ -0,0 +1 @@ +业务线UI有哪些操作? \ No newline at end of file diff --git a/tests/data/embedding/3.answer.md b/tests/data/embedding/3.answer.md new file mode 100644 index 000000000..35b0c6899 --- /dev/null +++ b/tests/data/embedding/3.answer.md @@ -0,0 +1,7 @@ +国家/区域导游详情 & 法律意见详情 查询 +Description:根据国家code查询国家/区域导游信息详情 +ID: 8 +HTTP METHOD: GET +Endpoint: /contract/country/navigate.json +Input Parameters: |名称|描述|类型(长度)|必选|备注| | :- | :- | :-: | :- | :- | |countryCode|国家code|string|√|| +Returns: |名称|描述|类型(长度)|必选|备注| | :- | :- | :-: | :- | :- | |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| |message|错误信息,可以用来提示|string|√|| |code|返回状态码|string|√|| |data|国家/区域导游详情|object|√|| |-> country||||| |-> -> id|id|integer|√|| |-> -> country|国家code|string|√|| |-> -> countryName|国家中文名称|string|√|| |-> -> countryNameEn|国家英文名称|string|√|| |-> -> content|国家导游中文详情json数组,具体格式见下示例|list of object|√|| |-> -> -> title|标题|object|√|| |-> -> -> -> title|中文标题|string||| |-> -> -> -> titleEn|英文标题|string||| |-> -> -> contentList|标题下面的文字描述列表|list of object|√|| |-> -> -> -> detail|内容中文详情|string|√|| |-> -> -> -> detailEn|内容英文详情|string|√|| |-> -> -> -> url|超链接|string||| |-> legal|法务信息|object||| |-> -> country|国家code|string|√|| |-> -> businessList|业务线列表|list of object||| |-> -> -> id|id|integer||新增时不传,修改时传递| |-> -> -> business|业务线code|string|√|| |-> -> -> businessName|业务线中文名称|string|√|| |-> -> -> businessNameEn|业务线英文名称|string|√|| |-> -> -> content|业务线json,具体如下|object|√|| |-> -> -> -> detailEn|具体的详情英文内容|string|√|| |-> -> -> -> detail|具体的详情内容|string|√|| \ No newline at end of file diff --git a/tests/data/embedding/3.knowledge.md b/tests/data/embedding/3.knowledge.md new file mode 100644 index 000000000..61de5f4b8 --- /dev/null +++ b/tests/data/embedding/3.knowledge.md @@ -0,0 +1,189 @@ +## Interfaces +- 用户登录 + - Description: 用户从小程序/微应用发起请求,需要验证用户的合法身份才能正常处理。 + - ID: 1 + - HTTP METHOD: GET + - Endpoint: `/sup/login.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |authCode|用户临时免登授权码|String(64)|√|| + |loginTypeEnum|登录类型|String(20)|√|| + |authCorpId|用户所在企业/组织id|String(64)||微应用免登时传递| + |app|应用标识|String(3)|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功与否,成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|用户的sessionId|string|√|| +- 根据sessionId查询用户详细信息 + - Description: 查询当前用户的详细信息,如 staffId,unionId,name,avatar等信息 + - ID: 2 + - HTTP METHOD: GET + - Endpoint: `/sup/user.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |NDA_SESSION|用户sessionId|String(64)|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功与否,成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|用户的详细信息|object|√|| + |-> corpId|当前用户企业 钉钉ID(小程序端会拿不到该信息)|string|√|| + |-> corpName|当前用户企业名称(小程序端会拿不到该信息)|string|√|| + |-> staffId|员工在当前企业内的唯一标识,也称staffId(小程序端会拿不到该信息)|string|√|| + |-> unionId|员工在当前开发者企业账号范围内的唯一标识,系统生成,固定值,不会改变。|string|√|| + |-> name|当前用户的名称(小程序端会拿不到该信息)|string|√|| + |-> avatar|头像图片URL|string|√|| +- 查询国家情况描述 + - Description: 根据国家code查询国家情况描述 + - ID: 3 + - HTTP METHOD: GET + - Endpoint: `/sup/country/detail.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |countryCode|国家code|string|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|国家情况描述|object|√|| + |-> id|id|integer|√|| + |-> countryName|国家名称|string|√|| + |-> countryCode|国家code|string|√|| + |-> detail|产品法规分析|string|√|| +- 查询产品法规分析(法律意见详情) + - Description: 根据国家和业务线查询产品法规分析 + - ID: 4 + - HTTP METHOD: GET + - Endpoint: `/sup/legal/detail.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |countryCode|国家code|string|√|| + |businessCode|业务线code|string|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|法律意见详情|object|√|| + |-> id|id|integer|√|| + |-> countryName|国家名称|string|√|| + |-> countryCode|国家code|string|√|| + |-> businessLine|业务线|string|√|| + |-> businessCode|业务线code|string|√|| + |-> detail|产品法规分析|string|√|| + |-> signEntity|签约主体|string|√|| +- 查询法律意见总数 + - Description: 法律意见总数查询 + - ID: 5 + - HTTP METHOD: GET + - Endpoint: `/sup/legal/count.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|总数|integer|√|| +- 查询所有国家和业务线信息列表 + - Description: 查询所有国家和业务线信息列表 + - ID: 6 + - HTTP METHOD: GET + - Endpoint: `/sup/legal/country/list.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|所有数据列表|list of object|√|| + |-> country|国家code|string|√|| + |-> business|业务线code|string|√|| + |-> dataType|数据类型|string|√|| + |-> businessName|业务线名|string|√|| + |-> countryName|国家名|string|√|| + |-> businessNameEn|业务线名(英文)|string|√|| +- 调用法务中台antlaw接口 + - ID: 7 +- 国家/区域导游详情 & 法律意见详情 查询 + - Description:根据国家code查询国家/区域导游信息详情 + - ID: 8 + - HTTP METHOD: GET + - Endpoint: `/contract/country/navigate.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |countryCode|国家code|string|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|国家/区域导游详情|object|√|| + |-> country||||| + |-> -> id|id|integer|√|| + |-> -> country|国家code|string|√|| + |-> -> countryName|国家中文名称|string|√|| + |-> -> countryNameEn|国家英文名称|string|√|| + |-> -> content|国家导游中文详情json数组,具体格式见下示例|list of object|√|| + |-> -> -> title|标题|object|√|| + |-> -> -> -> title|中文标题|string||| + |-> -> -> -> titleEn|英文标题|string||| + |-> -> -> contentList|标题下面的文字描述列表|list of object|√|| + |-> -> -> -> detail|内容中文详情|string|√|| + |-> -> -> -> detailEn|内容英文详情|string|√|| + |-> -> -> -> url|超链接|string||| + |-> legal|法务信息|object||| + |-> -> country|国家code|string|√|| + |-> -> businessList|业务线列表|list of object||| + |-> -> -> id|id|integer||新增时不传,修改时传递| + |-> -> -> business|业务线code|string|√|| + |-> -> -> businessName|业务线中文名称|string|√|| + |-> -> -> businessNameEn|业务线英文名称|string|√|| + |-> -> -> content|业务线json,具体如下|object|√|| + |-> -> -> -> detailEn|具体的详情英文内容|string|√|| + |-> -> -> -> detail|具体的详情内容|string|√|| +- 国家/区域导游列表分页查询 + - Description: 分页查询国家/区域列表 + - ID: 9 + - HTTP METHOD: GET + - Endpoint: `/contract/country/list.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |pageSize|分页大小|integer|√|>=1| + |pageNum|分页大小|integer|√|>=1| + |country|国家code|string||| + |business|业务线code|string||| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|国家/区域导游详情|list of object|√|| + |-> id|id|integer|√|| + |-> country|国家code|string|√|| + |-> countryName|国家中文名称|string|√|| + |-> countryNameEn|国家英文名称|string|√|| + |-> gmtCreate|创建时间|string|√|| + |-> gmtModified|更新时间|string|√|| + |total|数据总量|integer|√|| diff --git a/tests/data/embedding/3.query.md b/tests/data/embedding/3.query.md new file mode 100644 index 000000000..6026899d7 --- /dev/null +++ b/tests/data/embedding/3.query.md @@ -0,0 +1 @@ +根据国家code查询国家业务线列表 \ No newline at end of file diff --git a/tests/metagpt/rag/__init__.py b/tests/metagpt/rag/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/metagpt/rag/test_large_pdf.py b/tests/metagpt/rag/test_large_pdf.py new file mode 100644 index 000000000..4f343aa87 --- /dev/null +++ b/tests/metagpt/rag/test_large_pdf.py @@ -0,0 +1,55 @@ +import pytest + +from metagpt.config2 import Config +from metagpt.const import TEST_DATA_PATH +from metagpt.rag.engines import SimpleEngine +from metagpt.rag.factories.embedding import RAGEmbeddingFactory +from metagpt.utils.common import aread + + +@pytest.mark.skip +@pytest.mark.parametrize( + ("knowledge_filename", "query_filename", "answer_filename"), + [ + ( + TEST_DATA_PATH / "embedding/2.knowledge.md", + TEST_DATA_PATH / "embedding/2.query.md", + TEST_DATA_PATH / "embedding/2.answer.md", + ), + ( + TEST_DATA_PATH / "embedding/3.knowledge.md", + TEST_DATA_PATH / "embedding/3.query.md", + TEST_DATA_PATH / "embedding/3.answer.md", + ), + ], +) +@pytest.mark.asyncio +async def test_large_pdf(knowledge_filename, query_filename, answer_filename): + Config.default(reload=True) # `config.embedding.model = "text-embedding-ada-002"` changes the cache. + + engine = SimpleEngine.from_docs( + input_files=[knowledge_filename], + ) + + query = await aread(filename=query_filename) + rsp = await engine.aretrieve(query) + assert rsp + + config = Config.default() + config.embedding.model = "text-embedding-ada-002" + factory = RAGEmbeddingFactory(config) + embedding = factory.get_rag_embedding() + answer = await aread(filename=answer_filename) + answer_embedding = await embedding.aget_text_embedding(answer) + similarity = 0 + for i in rsp: + rsp_embedding = await embedding.aget_query_embedding(i.text) + v = embedding.similarity(answer_embedding, rsp_embedding) + similarity = max(similarity, v) + + print(similarity) + assert similarity > 0.9 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_index_repo.py b/tests/metagpt/tools/libs/test_index_repo.py new file mode 100644 index 000000000..3cc8ad406 --- /dev/null +++ b/tests/metagpt/tools/libs/test_index_repo.py @@ -0,0 +1,32 @@ +import shutil + +import pytest + +from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH +from metagpt.tools.libs.index_repo import IndexRepo + + +@pytest.mark.asyncio +@pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")]) +async def test_index_repo(path, query): + index_path = DEFAULT_WORKSPACE_ROOT / ".index" + repo = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0) + await repo.add([path]) + await repo.add([path]) + assert index_path.exists() + + rsp = await repo.search(query) + assert rsp + + repo2 = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0) + rsp2 = await repo2.search(query) + assert rsp2 + + merged_rsp = await repo.merge(query=query, indices_list=[rsp, rsp2]) + assert merged_rsp + + shutil.rmtree(index_path) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])