diff --git a/metagpt/base/base_role.py b/metagpt/base/base_role.py index b500b2cd6..1f7f00fa2 100644 --- a/metagpt/base/base_role.py +++ b/metagpt/base/base_role.py @@ -2,7 +2,6 @@ from abc import abstractmethod from typing import Optional, Union from metagpt.base.base_serialization import BaseSerialization -from metagpt.schema import Message class BaseRole(BaseSerialization): @@ -25,13 +24,13 @@ class BaseRole(BaseSerialization): raise NotImplementedError @abstractmethod - async def react(self) -> Message: + async def react(self) -> "Message": """Entry to one of three strategies by which Role reacts to the observed Message.""" @abstractmethod - async def run(self, with_message: Optional[Union[str, Message, list[str]]] = None) -> Optional[Message]: + async def run(self, with_message: Optional[Union[str, "Message", list[str]]] = None) -> Optional["Message"]: """Observe, and think and act based on the results of the observation.""" @abstractmethod - def get_memories(self, k: int = 0) -> list[Message]: + def get_memories(self, k: int = 0) -> list["Message"]: """Return the most recent k memories of this role.""" diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py index 20de47999..f9b41b9dc 100644 --- a/metagpt/configs/embedding_config.py +++ b/metagpt/configs/embedding_config.py @@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel): --------- api_type: "openai" api_key: "YOU_API_KEY" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "azure" api_key: "YOU_API_KEY" base_url: "YOU_BASE_URL" api_version: "YOU_API_VERSION" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "gemini" api_key: "YOU_API_KEY" @@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel): api_type: "ollama" base_url: "YOU_BASE_URL" model: "YOU_MODEL" + dimensions: "YOUR_MODEL_DIMENSIONS" """ api_type: Optional[EmbeddingType] = None @@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel): model: Optional[str] = None embed_batch_size: Optional[int] = None + dimensions: Optional[int] = None # output dimension of embedding model @field_validator("api_type", mode="before") @classmethod diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 75d8bfe00..f9111ffe0 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -27,7 +27,6 @@ from metagpt.configs.llm_config import LLMConfig from metagpt.const import IMAGES, LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT from metagpt.logs import logger from metagpt.provider.constant import MULTI_MODAL_MODELS -from metagpt.schema import Message from metagpt.utils.common import log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs from metagpt.utils.token_counter import TOKEN_MAX @@ -80,7 +79,7 @@ class BaseLLM(ABC): def support_image_input(self) -> bool: return any([m in self.config.model for m in MULTI_MODAL_MODELS]) - def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]: """convert messages to list[dict].""" from metagpt.schema import Message @@ -173,7 +172,9 @@ class BaseLLM(ABC): context.append(self._assistant_msg(rsp_text)) return self._extract_assistant_rsp(context) - async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs) -> dict: + async def aask_code( + self, messages: Union[str, "Message", list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs + ) -> dict: raise NotImplementedError @abstractmethod diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index e4b3a3f17..5c1b92503 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -22,7 +22,6 @@ from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider -from metagpt.schema import Message class GeminiGenerativeModel(GenerativeModel): @@ -73,7 +72,7 @@ class GeminiLLM(BaseLLM): def _system_msg(self, msg: str) -> dict[str, str]: return {"role": "user", "parts": [msg]} - def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]: """convert messages to list[dict].""" from metagpt.schema import Message diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index e48decdab..8d78fcad7 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -2,7 +2,8 @@ import json import os -from typing import Any, Optional, Union +from pathlib import Path +from typing import Any, List, Optional, Set, Union import fsspec from llama_index.core import SimpleDirectoryReader @@ -78,6 +79,7 @@ class SimpleEngine(RetrieverQueryEngine): callback_manager=callback_manager, ) self._transformations = transformations or self._default_transformations() + self._filenames = set() @classmethod def from_docs( @@ -192,11 +194,11 @@ class SimpleEngine(RetrieverQueryEngine): self._try_reconstruct_obj(nodes) return nodes - def add_docs(self, input_files: list[str]): + def add_docs(self, input_files: List[Union[str, Path]]): """Add docs to retriever. retriever must has add_nodes func.""" self._ensure_retriever_modifiable() - documents = SimpleDirectoryReader(input_files=input_files).load_data() + documents = SimpleDirectoryReader(input_files=[str(i) for i in input_files]).load_data() self._fix_document_metadata(documents) nodes = run_transformations(documents, transformations=self._transformations) @@ -227,6 +229,24 @@ class SimpleEngine(RetrieverQueryEngine): return self.retriever.clear(**kwargs) + def delete_docs(self, input_files: List[Union[str, Path]]): + """Delete documents from the index and document store. + + Args: + input_files (List[Union[str, Path]]): A list of file paths or file names to be deleted. + + Raises: + NotImplementedError: If the method is not implemented. + """ + exists_filenames = set() + filenames = {str(i) for i in input_files} + for doc_id, info in self.retriever._index.ref_doc_info.items(): + if info.metadata.get("file_path") in filenames: + exists_filenames.add(doc_id) + + for doc_id in exists_filenames: + self.retriever._index.delete_ref_doc(doc_id, delete_from_docstore=True) + @staticmethod def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]: """Converts a list of RAGObjects to a list of ObjectNodes.""" @@ -333,3 +353,7 @@ class SimpleEngine(RetrieverQueryEngine): @staticmethod def _default_transformations(): return [SentenceSplitter()] + + @property + def filenames(self) -> Set[str]: + return self._filenames diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index d647883bd..19b8b36f6 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -5,9 +5,6 @@ from typing import Any, Optional from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding -from llama_index.embeddings.gemini import GeminiEmbedding -from llama_index.embeddings.ollama import OllamaEmbedding -from llama_index.embeddings.openai import OpenAIEmbedding from metagpt.config2 import Config from metagpt.configs.embedding_config import EmbeddingType @@ -49,7 +46,9 @@ class RAGEmbeddingFactory(GenericFactory): raise TypeError("To use RAG, please set your embedding in config2.yaml.") - def _create_openai(self) -> OpenAIEmbedding: + def _create_openai(self) -> "OpenAIEmbedding": + from llama_index.embeddings.openai import OpenAIEmbedding + params = dict( api_key=self.config.embedding.api_key or self.config.llm.api_key, api_base=self.config.embedding.base_url or self.config.llm.base_url, @@ -70,7 +69,9 @@ class RAGEmbeddingFactory(GenericFactory): return AzureOpenAIEmbedding(**params) - def _create_gemini(self) -> GeminiEmbedding: + def _create_gemini(self) -> "GeminiEmbedding": + from llama_index.embeddings.gemini import GeminiEmbedding + params = dict( api_key=self.config.embedding.api_key, api_base=self.config.embedding.base_url, @@ -80,7 +81,9 @@ class RAGEmbeddingFactory(GenericFactory): return GeminiEmbedding(**params) - def _create_ollama(self) -> OllamaEmbedding: + def _create_ollama(self) -> "OllamaEmbedding": + from llama_index.embeddings.ollama import OllamaEmbedding + params = dict( base_url=self.config.embedding.base_url, ) diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 59f6db4d9..bd252771a 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -13,7 +13,6 @@ from llama_index.core.llms.callbacks import llm_completion_callback from pydantic import Field from metagpt.config2 import Config -from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.token_counter import TOKEN_MAX @@ -79,4 +78,6 @@ class RAGLLM(CustomLLM): def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM: """Get llm that can be used by LlamaIndex.""" + from metagpt.llm import LLM + return RAGLLM(model_infer=model_infer or LLM()) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py new file mode 100644 index 000000000..fadc11522 --- /dev/null +++ b/metagpt/tools/libs/index_repo.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import json +from pathlib import Path +from typing import Dict, List, Optional, Set, Union + +import tiktoken +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.schema import NodeWithScore +from pydantic import BaseModel, Field, model_validator + +from metagpt.config2 import Config +from metagpt.logs import logger +from metagpt.rag.engines import SimpleEngine +from metagpt.rag.factories.embedding import RAGEmbeddingFactory +from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig +from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files +from metagpt.utils.repo_to_markdown import is_text_file + + +class TextScore(BaseModel): + filename: str + text: str + score: Optional[float] = None + + +class IndexRepo(BaseModel): + persist_path: str # The persist path of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/ + root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. + fingerprint_filename: str = "fingerprint.json" + model: Optional[str] = None + min_token_count: int = 10000 + max_token_count: int = 100000000 + recall_count: int = 5 + embedding: Optional[BaseEmbedding] = Field(default=None, exclude=True) + fingerprints: Dict[str, str] = Field(default_factory=dict) + + @model_validator(mode="after") + def _update_fingerprints(self) -> "IndexRepo": + """Load fingerprints from the fingerprint file if not already loaded. + + Returns: + IndexRepo: The updated IndexRepo instance. + """ + if not self.fingerprints: + filename = Path(self.persist_path) / self.fingerprint_filename + if not filename.exists(): + return self + with open(str(filename), "r") as reader: + self.fingerprints = json.load(reader) + return self + + async def search( + self, query: str, filenames: Optional[List[Path]] = None + ) -> Optional[List[Union[NodeWithScore, TextScore]]]: + """Search for documents related to the given query. + + Args: + query (str): The search query. + filenames (Optional[List[Path]]): A list of filenames to filter the search. + + Returns: + Optional[List[Union[NodeWithScore, TextScore]]]: A list of search results containing NodeWithScore or TextScore. + """ + encoding = tiktoken.get_encoding("cl100k_base") + result: List[Union[NodeWithScore, TextScore]] = [] + filenames, _ = await self._filter(filenames) + filter_filenames = set() + for i in filenames: + content = await aread(filename=i) + token_count = len(encoding.encode(content)) + if not self._is_buildable(token_count): + result.append(TextScore(filename=str(i), text=content)) + continue + file_fingerprint = generate_fingerprint(content) + if self.fingerprints.get(str(i)) != file_fingerprint: + logger.error(f'file: "{i}" changed but not indexed') + continue + filter_filenames.add(str(i)) + nodes = await self._search(query=query, filters=filter_filenames) + return result + nodes + + async def merge( + self, query: str, indices_list: List[List[Union[NodeWithScore, TextScore]]] + ) -> List[Union[NodeWithScore, TextScore]]: + """Merge results from multiple indices based on the query. + + Args: + query (str): The search query. + indices_list (List[List[Union[NodeWithScore, TextScore]]]): A list of result lists from different indices. + + Returns: + List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity. + """ + if not self.embedding: + config = Config.default() + if self.model: + config.embedding.model = self.model + factory = RAGEmbeddingFactory(config) + self.embedding = factory.get_rag_embedding() + + scores = [] + query_embedding = await self.embedding.aget_text_embedding(query) + flat_nodes = [node for indices in indices_list for node in indices] + for i in flat_nodes: + text_embedding = await self.embedding.aget_text_embedding(i.text) + similarity = self.embedding.similarity(query_embedding, text_embedding) + scores.append((similarity, i)) + scores.sort(key=lambda x: x[0], reverse=True) + return [i[1] for i in scores][: self.recall_count] + + async def add(self, paths: List[Path]): + """Add new documents to the index. + + Args: + paths (List[Path]): A list of paths to the documents to be added. + """ + encoding = tiktoken.get_encoding("cl100k_base") + filenames, _ = await self._filter(paths) + filter_filenames = [] + delete_filenames = [] + for i in filenames: + content = await aread(filename=i) + if not self._is_fingerprint_changed(filename=i, content=content): + continue + token_count = len(encoding.encode(content)) + if self._is_buildable(token_count): + filter_filenames.append(i) + logger.debug(f"{i} is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}") + else: + delete_filenames.append(i) + logger.debug(f"{i} not is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}") + await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames) + + async def _add_batch(self, filenames: List[Union[str, Path]], delete_filenames: List[Union[str, Path]]): + """Add and remove documents in a batch operation. + + Args: + filenames (List[Union[str, Path]]): List of filenames to add. + delete_filenames (List[Union[str, Path]]): List of filenames to delete. + """ + if not filenames: + return + logger.info(f"update index repo, add {filenames}, remove {delete_filenames}") + engine = None + if Path(self.persist_path).exists(): + logger.debug(f"load index from {self.persist_path}") + engine = SimpleEngine.from_index( + index_config=FAISSIndexConfig(persist_path=self.persist_path), + retriever_configs=[FAISSRetrieverConfig()], + ) + try: + engine.delete_docs(filenames + delete_filenames) + logger.debug(f"delete docs {filenames + delete_filenames}") + engine.add_docs(input_files=filenames) + logger.debug(f"add docs {filenames}") + except NotImplementedError as e: + logger.debug(f"{e}") + filenames = list(set([str(i) for i in filenames] + list(self.fingerprints.keys()))) + engine = None + logger.info(f"{e}. Rebuild all.") + if not engine: + engine = SimpleEngine.from_docs( + input_files=[str(i) for i in filenames], + retriever_configs=[FAISSRetrieverConfig()], + ranker_configs=[LLMRankerConfig()], + ) + logger.debug(f"add docs {filenames}") + engine.persist(persist_dir=self.persist_path) + for i in filenames: + content = await aread(i) + fp = generate_fingerprint(content) + self.fingerprints[str(i)] = fp + await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) + + def __str__(self): + """Return a string representation of the IndexRepo. + + Returns: + str: The filename of the index repository. + """ + return f"{self.persist_path}" + + def _is_buildable(self, token_count: int) -> bool: + """Check if the token count is within the buildable range. + + Args: + token_count (int): The number of tokens in the content. + + Returns: + bool: True if buildable, False otherwise. + """ + if token_count < self.min_token_count or token_count > self.max_token_count: + return False + return True + + async def _filter(self, filenames: Optional[List[Union[str, Path]]] = None) -> (List[Path], List[Path]): + """Filter the provided filenames to only include valid text files. + + Args: + filenames (Optional[List[Union[str, Path]]]): List of filenames to filter. + + Returns: + Tuple[List[Path], List[Path]]: A tuple containing a list of valid pathnames and a list of excluded paths. + """ + root_path = Path(self.root_path).absolute() + if not filenames: + filenames = [root_path] + pathnames = [] + excludes = [] + for i in filenames: + path = Path(i).absolute() + if not path.is_relative_to(root_path): + excludes.append(path) + logger.debug(f"{path} not is_relative_to {root_path})") + continue + if not path.is_dir(): + is_text, _ = await is_text_file(path) + if is_text: + pathnames.append(path) + continue + subfiles = list_files(path) + for j in subfiles: + is_text, _ = await is_text_file(j) + if is_text: + pathnames.append(j) + + logger.debug(f"{pathnames}, excludes:{excludes})") + return pathnames, excludes + + async def _search(self, query: str, filters: Set[str]) -> List[NodeWithScore]: + """Perform a search for the given query using the index. + + Args: + query (str): The search query. + filters (Set[str]): A set of filenames to filter the search results. + + Returns: + List[NodeWithScore]: A list of nodes with scores matching the query. + """ + if not Path(self.persist_path).exists(): + return [] + engine = SimpleEngine.from_index( + index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()] + ) + rsp = await engine.aretrieve(query) + return [i for i in rsp if i.metadata.get("file_path") in filters] + + def _is_fingerprint_changed(self, filename: Union[str, Path], content: str) -> bool: + """Check if the fingerprint of the given document content has changed. + + Args: + filename (Union[str, Path]): The filename of the document. + content (str): The content of the document. + + Returns: + bool: True if the fingerprint has changed, False otherwise. + """ + old_fp = self.fingerprints.get(str(filename)) + if not old_fp: + return True + fp = generate_fingerprint(content) + return old_fp != fp diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 42a872c76..90f13da23 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -16,6 +16,7 @@ import base64 import contextlib import csv import functools +import hashlib import importlib import inspect import json @@ -889,7 +890,7 @@ async def get_mime_type(filename: str | Path, force_read: bool = False) -> str: } try: - stdout, stderr, _ = await shell_execute(f"file --mime-type {str(filename)}") + stdout, stderr, _ = await shell_execute(f"file --mime-type '{str(filename)}'") if stderr: logger.debug(f"file:{filename}, error:{stderr}") return guess_mime_type @@ -1175,3 +1176,23 @@ def rectify_pathname(path: Union[str, Path], default_filename: str) -> Path: else: output_pathname.parent.mkdir(parents=True, exist_ok=True) return output_pathname + + +def generate_fingerprint(text: str) -> str: + """ + Generate a fingerprint for the given text + + Args: + text (str): The text for which the fingerprint needs to be generated + + Returns: + str: The fingerprint value of the text + """ + text_bytes = text.encode("utf-8") + + # calculate SHA-256 hash + sha256 = hashlib.sha256() + sha256.update(text_bytes) + fingerprint = sha256.hexdigest() + + return fingerprint diff --git a/tests/data/embedding/2.answer.md b/tests/data/embedding/2.answer.md new file mode 100644 index 000000000..3807f03c1 --- /dev/null +++ b/tests/data/embedding/2.answer.md @@ -0,0 +1,2 @@ +检索结果 +法务查询者可根据国际小超人钉钉小程序UI上的滚筒切换业务线 这张图片展示了一个移动应用的界面,界面标题为“法律意见详情”。用户可以根据具体情况切换业务线。界面中有多个字段,包括“国家名称”、“国家情况描述”、“业务线”、“产品法规分析”和“签约主体”。第一张截图显示了详细的法律情报信息,包含区域名称、区域情况描述、业务线和产品法规概述等字段。第二张截图显示了“法律意见详情”界面,其中列出了国家名称、国家情况描述、业务线、产品法规分析和签约主体。第三张截图与第二张相似,但显示了选项的可选择状态。最下方有“取消”和“确定”的按钮。 法务查询者从国家详情中的业务线名列表中选出要查看的业务线。 \ No newline at end of file diff --git a/tests/data/embedding/2.knowledge.md b/tests/data/embedding/2.knowledge.md new file mode 100644 index 000000000..615614098 --- /dev/null +++ b/tests/data/embedding/2.knowledge.md @@ -0,0 +1,25 @@ +## Textual User Requirements + +### 3.2. 首页 + +首页有两个分区,上面部分是法律意见检索栏。 + +法务查询者第一次进入国际小超人钉钉小程序展示引导页,以后进入不再展示,点击「我知道了」引导页消失。 + +#### 首页 +![首页](1.png) +这是一个名为“法务小超人”的移动应用程序的界面截图。界面顶部显示了应用名称和一个可切换语言的按钮“English”。在界面中间部分,有一个标题“法律意见查询”,以及一个搜索框,提示输入国家名称以查询法律意见。下方显示已收录法律意见8394篇。界面下半部分是“法务 Q&A”部分,列出了一些法律相关的选项,例如“国际法务接入口人”、“国内法务接入口人”、“国际法律协议合同办理指引”和“国内法律协议合同办理指引”。界面底部有三个导航按钮,分别是“首页”、“模板”和“我的”。 + +#### 按国家名维度搜索 +法务查询者在国际小超人钉钉小程序的搜索框中进行检索时采用typeahead,只能下拉选择法务中台中有的国家名称。 +![按国家名维度搜索](2.png) +在这张图像中,用户正在一个名为“法律意见查询”的应用中进行国家名称的搜索。用户在搜索框中输入国家名称时,系统会提供下拉建议。这些建议基于 typeahead 功能,从法务中台中筛选出匹配的国家名称供用户选择。目前,搜索结果包含了“中国”和“菲律宾”两个具体的国家名称,其它显示为“国家名”。用户可以通过下拉菜单快速选择所需的国家名称。 + +#### 检索结果 +法务查询者可根据国际小超人钉钉小程序UI上的滚筒切换业务线 +![检索结果](3.png) +这张图片展示了一个移动应用的界面,界面标题为“法律意见详情”。用户可以根据具体情况切换业务线。界面中有多个字段,包括“国家名称”、“国家情况描述”、“业务线”、“产品法规分析”和“签约主体”。第一张截图显示了详细的法律情报信息,包含区域名称、区域情况描述、业务线和产品法规概述等字段。第二张截图显示了“法律意见详情”界面,其中列出了国家名称、国家情况描述、业务线、产品法规分析和签约主体。第三张截图与第二张相似,但显示了选项的可选择状态。最下方有“取消”和“确定”的按钮。 +法务查询者从国家详情中的业务线名列表中选出要查看的业务线。 + +#### 查看法律意见详情 +国际小超人钉钉小程序用国家代码和业务代码做参数,查询法律意见详情,然后将法律意见详情展示给法务查询者。 \ No newline at end of file diff --git a/tests/data/embedding/2.query.md b/tests/data/embedding/2.query.md new file mode 100644 index 000000000..ba470b8bd --- /dev/null +++ b/tests/data/embedding/2.query.md @@ -0,0 +1 @@ +业务线UI有哪些操作? \ No newline at end of file diff --git a/tests/data/embedding/3.answer.md b/tests/data/embedding/3.answer.md new file mode 100644 index 000000000..35b0c6899 --- /dev/null +++ b/tests/data/embedding/3.answer.md @@ -0,0 +1,7 @@ +国家/区域导游详情 & 法律意见详情 查询 +Description:根据国家code查询国家/区域导游信息详情 +ID: 8 +HTTP METHOD: GET +Endpoint: /contract/country/navigate.json +Input Parameters: |名称|描述|类型(长度)|必选|备注| | :- | :- | :-: | :- | :- | |countryCode|国家code|string|√|| +Returns: |名称|描述|类型(长度)|必选|备注| | :- | :- | :-: | :- | :- | |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| |message|错误信息,可以用来提示|string|√|| |code|返回状态码|string|√|| |data|国家/区域导游详情|object|√|| |-> country||||| |-> -> id|id|integer|√|| |-> -> country|国家code|string|√|| |-> -> countryName|国家中文名称|string|√|| |-> -> countryNameEn|国家英文名称|string|√|| |-> -> content|国家导游中文详情json数组,具体格式见下示例|list of object|√|| |-> -> -> title|标题|object|√|| |-> -> -> -> title|中文标题|string||| |-> -> -> -> titleEn|英文标题|string||| |-> -> -> contentList|标题下面的文字描述列表|list of object|√|| |-> -> -> -> detail|内容中文详情|string|√|| |-> -> -> -> detailEn|内容英文详情|string|√|| |-> -> -> -> url|超链接|string||| |-> legal|法务信息|object||| |-> -> country|国家code|string|√|| |-> -> businessList|业务线列表|list of object||| |-> -> -> id|id|integer||新增时不传,修改时传递| |-> -> -> business|业务线code|string|√|| |-> -> -> businessName|业务线中文名称|string|√|| |-> -> -> businessNameEn|业务线英文名称|string|√|| |-> -> -> content|业务线json,具体如下|object|√|| |-> -> -> -> detailEn|具体的详情英文内容|string|√|| |-> -> -> -> detail|具体的详情内容|string|√|| \ No newline at end of file diff --git a/tests/data/embedding/3.knowledge.md b/tests/data/embedding/3.knowledge.md new file mode 100644 index 000000000..61de5f4b8 --- /dev/null +++ b/tests/data/embedding/3.knowledge.md @@ -0,0 +1,189 @@ +## Interfaces +- 用户登录 + - Description: 用户从小程序/微应用发起请求,需要验证用户的合法身份才能正常处理。 + - ID: 1 + - HTTP METHOD: GET + - Endpoint: `/sup/login.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |authCode|用户临时免登授权码|String(64)|√|| + |loginTypeEnum|登录类型|String(20)|√|| + |authCorpId|用户所在企业/组织id|String(64)||微应用免登时传递| + |app|应用标识|String(3)|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功与否,成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|用户的sessionId|string|√|| +- 根据sessionId查询用户详细信息 + - Description: 查询当前用户的详细信息,如 staffId,unionId,name,avatar等信息 + - ID: 2 + - HTTP METHOD: GET + - Endpoint: `/sup/user.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |NDA_SESSION|用户sessionId|String(64)|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功与否,成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|用户的详细信息|object|√|| + |-> corpId|当前用户企业 钉钉ID(小程序端会拿不到该信息)|string|√|| + |-> corpName|当前用户企业名称(小程序端会拿不到该信息)|string|√|| + |-> staffId|员工在当前企业内的唯一标识,也称staffId(小程序端会拿不到该信息)|string|√|| + |-> unionId|员工在当前开发者企业账号范围内的唯一标识,系统生成,固定值,不会改变。|string|√|| + |-> name|当前用户的名称(小程序端会拿不到该信息)|string|√|| + |-> avatar|头像图片URL|string|√|| +- 查询国家情况描述 + - Description: 根据国家code查询国家情况描述 + - ID: 3 + - HTTP METHOD: GET + - Endpoint: `/sup/country/detail.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |countryCode|国家code|string|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|国家情况描述|object|√|| + |-> id|id|integer|√|| + |-> countryName|国家名称|string|√|| + |-> countryCode|国家code|string|√|| + |-> detail|产品法规分析|string|√|| +- 查询产品法规分析(法律意见详情) + - Description: 根据国家和业务线查询产品法规分析 + - ID: 4 + - HTTP METHOD: GET + - Endpoint: `/sup/legal/detail.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |countryCode|国家code|string|√|| + |businessCode|业务线code|string|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|法律意见详情|object|√|| + |-> id|id|integer|√|| + |-> countryName|国家名称|string|√|| + |-> countryCode|国家code|string|√|| + |-> businessLine|业务线|string|√|| + |-> businessCode|业务线code|string|√|| + |-> detail|产品法规分析|string|√|| + |-> signEntity|签约主体|string|√|| +- 查询法律意见总数 + - Description: 法律意见总数查询 + - ID: 5 + - HTTP METHOD: GET + - Endpoint: `/sup/legal/count.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|总数|integer|√|| +- 查询所有国家和业务线信息列表 + - Description: 查询所有国家和业务线信息列表 + - ID: 6 + - HTTP METHOD: GET + - Endpoint: `/sup/legal/country/list.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|所有数据列表|list of object|√|| + |-> country|国家code|string|√|| + |-> business|业务线code|string|√|| + |-> dataType|数据类型|string|√|| + |-> businessName|业务线名|string|√|| + |-> countryName|国家名|string|√|| + |-> businessNameEn|业务线名(英文)|string|√|| +- 调用法务中台antlaw接口 + - ID: 7 +- 国家/区域导游详情 & 法律意见详情 查询 + - Description:根据国家code查询国家/区域导游信息详情 + - ID: 8 + - HTTP METHOD: GET + - Endpoint: `/contract/country/navigate.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |countryCode|国家code|string|√|| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|国家/区域导游详情|object|√|| + |-> country||||| + |-> -> id|id|integer|√|| + |-> -> country|国家code|string|√|| + |-> -> countryName|国家中文名称|string|√|| + |-> -> countryNameEn|国家英文名称|string|√|| + |-> -> content|国家导游中文详情json数组,具体格式见下示例|list of object|√|| + |-> -> -> title|标题|object|√|| + |-> -> -> -> title|中文标题|string||| + |-> -> -> -> titleEn|英文标题|string||| + |-> -> -> contentList|标题下面的文字描述列表|list of object|√|| + |-> -> -> -> detail|内容中文详情|string|√|| + |-> -> -> -> detailEn|内容英文详情|string|√|| + |-> -> -> -> url|超链接|string||| + |-> legal|法务信息|object||| + |-> -> country|国家code|string|√|| + |-> -> businessList|业务线列表|list of object||| + |-> -> -> id|id|integer||新增时不传,修改时传递| + |-> -> -> business|业务线code|string|√|| + |-> -> -> businessName|业务线中文名称|string|√|| + |-> -> -> businessNameEn|业务线英文名称|string|√|| + |-> -> -> content|业务线json,具体如下|object|√|| + |-> -> -> -> detailEn|具体的详情英文内容|string|√|| + |-> -> -> -> detail|具体的详情内容|string|√|| +- 国家/区域导游列表分页查询 + - Description: 分页查询国家/区域列表 + - ID: 9 + - HTTP METHOD: GET + - Endpoint: `/contract/country/list.json` + - Input Parameters: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |pageSize|分页大小|integer|√|>=1| + |pageNum|分页大小|integer|√|>=1| + |country|国家code|string||| + |business|业务线code|string||| + - Returns: + |名称|描述|类型(长度)|必选|备注| + | :- | :- | :-: | :- | :- | + |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| + |message|错误信息,可以用来提示|string|√|| + |code|返回状态码|string|√|| + |data|国家/区域导游详情|list of object|√|| + |-> id|id|integer|√|| + |-> country|国家code|string|√|| + |-> countryName|国家中文名称|string|√|| + |-> countryNameEn|国家英文名称|string|√|| + |-> gmtCreate|创建时间|string|√|| + |-> gmtModified|更新时间|string|√|| + |total|数据总量|integer|√|| diff --git a/tests/data/embedding/3.query.md b/tests/data/embedding/3.query.md new file mode 100644 index 000000000..6026899d7 --- /dev/null +++ b/tests/data/embedding/3.query.md @@ -0,0 +1 @@ +根据国家code查询国家业务线列表 \ No newline at end of file diff --git a/tests/metagpt/rag/__init__.py b/tests/metagpt/rag/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/metagpt/rag/test_large_pdf.py b/tests/metagpt/rag/test_large_pdf.py new file mode 100644 index 000000000..4f343aa87 --- /dev/null +++ b/tests/metagpt/rag/test_large_pdf.py @@ -0,0 +1,55 @@ +import pytest + +from metagpt.config2 import Config +from metagpt.const import TEST_DATA_PATH +from metagpt.rag.engines import SimpleEngine +from metagpt.rag.factories.embedding import RAGEmbeddingFactory +from metagpt.utils.common import aread + + +@pytest.mark.skip +@pytest.mark.parametrize( + ("knowledge_filename", "query_filename", "answer_filename"), + [ + ( + TEST_DATA_PATH / "embedding/2.knowledge.md", + TEST_DATA_PATH / "embedding/2.query.md", + TEST_DATA_PATH / "embedding/2.answer.md", + ), + ( + TEST_DATA_PATH / "embedding/3.knowledge.md", + TEST_DATA_PATH / "embedding/3.query.md", + TEST_DATA_PATH / "embedding/3.answer.md", + ), + ], +) +@pytest.mark.asyncio +async def test_large_pdf(knowledge_filename, query_filename, answer_filename): + Config.default(reload=True) # `config.embedding.model = "text-embedding-ada-002"` changes the cache. + + engine = SimpleEngine.from_docs( + input_files=[knowledge_filename], + ) + + query = await aread(filename=query_filename) + rsp = await engine.aretrieve(query) + assert rsp + + config = Config.default() + config.embedding.model = "text-embedding-ada-002" + factory = RAGEmbeddingFactory(config) + embedding = factory.get_rag_embedding() + answer = await aread(filename=answer_filename) + answer_embedding = await embedding.aget_text_embedding(answer) + similarity = 0 + for i in rsp: + rsp_embedding = await embedding.aget_query_embedding(i.text) + v = embedding.similarity(answer_embedding, rsp_embedding) + similarity = max(similarity, v) + + print(similarity) + assert similarity > 0.9 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_index_repo.py b/tests/metagpt/tools/libs/test_index_repo.py new file mode 100644 index 000000000..3cc8ad406 --- /dev/null +++ b/tests/metagpt/tools/libs/test_index_repo.py @@ -0,0 +1,32 @@ +import shutil + +import pytest + +from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH +from metagpt.tools.libs.index_repo import IndexRepo + + +@pytest.mark.asyncio +@pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")]) +async def test_index_repo(path, query): + index_path = DEFAULT_WORKSPACE_ROOT / ".index" + repo = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0) + await repo.add([path]) + await repo.add([path]) + assert index_path.exists() + + rsp = await repo.search(query) + assert rsp + + repo2 = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0) + rsp2 = await repo2.search(query) + assert rsp2 + + merged_rsp = await repo.merge(query=query, indices_list=[rsp, rsp2]) + assert merged_rsp + + shutil.rmtree(index_path) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])