mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
Merge branch 'feature/rfc258' into 'mgx_ops'
Feat: RFC-258-Editor.search设计方案对应的Index Repo See merge request pub/MetaGPT!362
This commit is contained in:
commit
ac29811c2f
18 changed files with 648 additions and 20 deletions
|
|
@ -2,7 +2,6 @@ from abc import abstractmethod
|
|||
from typing import Optional, Union
|
||||
|
||||
from metagpt.base.base_serialization import BaseSerialization
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class BaseRole(BaseSerialization):
|
||||
|
|
@ -25,13 +24,13 @@ class BaseRole(BaseSerialization):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def react(self) -> Message:
|
||||
async def react(self) -> "Message":
|
||||
"""Entry to one of three strategies by which Role reacts to the observed Message."""
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, with_message: Optional[Union[str, Message, list[str]]] = None) -> Optional[Message]:
|
||||
async def run(self, with_message: Optional[Union[str, "Message", list[str]]] = None) -> Optional["Message"]:
|
||||
"""Observe, and think and act based on the results of the observation."""
|
||||
|
||||
@abstractmethod
|
||||
def get_memories(self, k: int = 0) -> list[Message]:
|
||||
def get_memories(self, k: int = 0) -> list["Message"]:
|
||||
"""Return the most recent k memories of this role."""
|
||||
|
|
|
|||
|
|
@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel):
|
|||
---------
|
||||
api_type: "openai"
|
||||
api_key: "YOU_API_KEY"
|
||||
dimensions: "YOUR_MODEL_DIMENSIONS"
|
||||
|
||||
api_type: "azure"
|
||||
api_key: "YOU_API_KEY"
|
||||
base_url: "YOU_BASE_URL"
|
||||
api_version: "YOU_API_VERSION"
|
||||
dimensions: "YOUR_MODEL_DIMENSIONS"
|
||||
|
||||
api_type: "gemini"
|
||||
api_key: "YOU_API_KEY"
|
||||
|
|
@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel):
|
|||
api_type: "ollama"
|
||||
base_url: "YOU_BASE_URL"
|
||||
model: "YOU_MODEL"
|
||||
dimensions: "YOUR_MODEL_DIMENSIONS"
|
||||
"""
|
||||
|
||||
api_type: Optional[EmbeddingType] = None
|
||||
|
|
@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel):
|
|||
|
||||
model: Optional[str] = None
|
||||
embed_batch_size: Optional[int] = None
|
||||
dimensions: Optional[int] = None # output dimension of embedding model
|
||||
|
||||
@field_validator("api_type", mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@ from metagpt.configs.llm_config import LLMConfig
|
|||
from metagpt.const import IMAGES, LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.constant import MULTI_MODAL_MODELS
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.token_counter import TOKEN_MAX
|
||||
|
|
@ -80,7 +79,7 @@ class BaseLLM(ABC):
|
|||
def support_image_input(self) -> bool:
|
||||
return any([m in self.config.model for m in MULTI_MODAL_MODELS])
|
||||
|
||||
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -173,7 +172,9 @@ class BaseLLM(ABC):
|
|||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs) -> dict:
|
||||
async def aask_code(
|
||||
self, messages: Union[str, "Message", list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs
|
||||
) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from metagpt.const import USE_CONFIG_TIMEOUT
|
|||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class GeminiGenerativeModel(GenerativeModel):
|
||||
|
|
@ -73,7 +72,7 @@ class GeminiLLM(BaseLLM):
|
|||
def _system_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "user", "parts": [msg]}
|
||||
|
||||
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Set, Union
|
||||
|
||||
import fsspec
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
|
@ -78,6 +79,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
callback_manager=callback_manager,
|
||||
)
|
||||
self._transformations = transformations or self._default_transformations()
|
||||
self._filenames = set()
|
||||
|
||||
@classmethod
|
||||
def from_docs(
|
||||
|
|
@ -192,11 +194,11 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
self._try_reconstruct_obj(nodes)
|
||||
return nodes
|
||||
|
||||
def add_docs(self, input_files: list[str]):
|
||||
def add_docs(self, input_files: List[Union[str, Path]]):
|
||||
"""Add docs to retriever. retriever must has add_nodes func."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
documents = SimpleDirectoryReader(input_files=[str(i) for i in input_files]).load_data()
|
||||
self._fix_document_metadata(documents)
|
||||
|
||||
nodes = run_transformations(documents, transformations=self._transformations)
|
||||
|
|
@ -227,6 +229,24 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
|
||||
return self.retriever.clear(**kwargs)
|
||||
|
||||
def delete_docs(self, input_files: List[Union[str, Path]]):
|
||||
"""Delete documents from the index and document store.
|
||||
|
||||
Args:
|
||||
input_files (List[Union[str, Path]]): A list of file paths or file names to be deleted.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the method is not implemented.
|
||||
"""
|
||||
exists_filenames = set()
|
||||
filenames = {str(i) for i in input_files}
|
||||
for doc_id, info in self.retriever._index.ref_doc_info.items():
|
||||
if info.metadata.get("file_path") in filenames:
|
||||
exists_filenames.add(doc_id)
|
||||
|
||||
for doc_id in exists_filenames:
|
||||
self.retriever._index.delete_ref_doc(doc_id, delete_from_docstore=True)
|
||||
|
||||
@staticmethod
|
||||
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
|
||||
"""Converts a list of RAGObjects to a list of ObjectNodes."""
|
||||
|
|
@ -333,3 +353,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
@staticmethod
|
||||
def _default_transformations():
|
||||
return [SentenceSplitter()]
|
||||
|
||||
@property
|
||||
def filenames(self) -> Set[str]:
|
||||
return self._filenames
|
||||
|
|
|
|||
|
|
@ -5,9 +5,6 @@ from typing import Any, Optional
|
|||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
|
|
@ -49,7 +46,9 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
|
||||
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
|
||||
|
||||
def _create_openai(self) -> OpenAIEmbedding:
|
||||
def _create_openai(self) -> "OpenAIEmbedding":
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
params = dict(
|
||||
api_key=self.config.embedding.api_key or self.config.llm.api_key,
|
||||
api_base=self.config.embedding.base_url or self.config.llm.base_url,
|
||||
|
|
@ -70,7 +69,9 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
|
||||
return AzureOpenAIEmbedding(**params)
|
||||
|
||||
def _create_gemini(self) -> GeminiEmbedding:
|
||||
def _create_gemini(self) -> "GeminiEmbedding":
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
|
||||
params = dict(
|
||||
api_key=self.config.embedding.api_key,
|
||||
api_base=self.config.embedding.base_url,
|
||||
|
|
@ -80,7 +81,9 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
|
||||
return GeminiEmbedding(**params)
|
||||
|
||||
def _create_ollama(self) -> OllamaEmbedding:
|
||||
def _create_ollama(self) -> "OllamaEmbedding":
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
|
||||
params = dict(
|
||||
base_url=self.config.embedding.base_url,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from llama_index.core.llms.callbacks import llm_completion_callback
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.token_counter import TOKEN_MAX
|
||||
|
|
@ -79,4 +78,6 @@ class RAGLLM(CustomLLM):
|
|||
|
||||
def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM:
|
||||
"""Get llm that can be used by LlamaIndex."""
|
||||
from metagpt.llm import LLM
|
||||
|
||||
return RAGLLM(model_infer=model_infer or LLM())
|
||||
|
|
|
|||
264
metagpt/tools/libs/index_repo.py
Normal file
264
metagpt/tools/libs/index_repo.py
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
import tiktoken
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig
|
||||
from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files
|
||||
from metagpt.utils.repo_to_markdown import is_text_file
|
||||
|
||||
|
||||
class TextScore(BaseModel):
|
||||
filename: str
|
||||
text: str
|
||||
score: Optional[float] = None
|
||||
|
||||
|
||||
class IndexRepo(BaseModel):
|
||||
persist_path: str # The persist path of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/
|
||||
root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo.
|
||||
fingerprint_filename: str = "fingerprint.json"
|
||||
model: Optional[str] = None
|
||||
min_token_count: int = 10000
|
||||
max_token_count: int = 100000000
|
||||
recall_count: int = 5
|
||||
embedding: Optional[BaseEmbedding] = Field(default=None, exclude=True)
|
||||
fingerprints: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _update_fingerprints(self) -> "IndexRepo":
|
||||
"""Load fingerprints from the fingerprint file if not already loaded.
|
||||
|
||||
Returns:
|
||||
IndexRepo: The updated IndexRepo instance.
|
||||
"""
|
||||
if not self.fingerprints:
|
||||
filename = Path(self.persist_path) / self.fingerprint_filename
|
||||
if not filename.exists():
|
||||
return self
|
||||
with open(str(filename), "r") as reader:
|
||||
self.fingerprints = json.load(reader)
|
||||
return self
|
||||
|
||||
async def search(
|
||||
self, query: str, filenames: Optional[List[Path]] = None
|
||||
) -> Optional[List[Union[NodeWithScore, TextScore]]]:
|
||||
"""Search for documents related to the given query.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
filenames (Optional[List[Path]]): A list of filenames to filter the search.
|
||||
|
||||
Returns:
|
||||
Optional[List[Union[NodeWithScore, TextScore]]]: A list of search results containing NodeWithScore or TextScore.
|
||||
"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
result: List[Union[NodeWithScore, TextScore]] = []
|
||||
filenames, _ = await self._filter(filenames)
|
||||
filter_filenames = set()
|
||||
for i in filenames:
|
||||
content = await aread(filename=i)
|
||||
token_count = len(encoding.encode(content))
|
||||
if not self._is_buildable(token_count):
|
||||
result.append(TextScore(filename=str(i), text=content))
|
||||
continue
|
||||
file_fingerprint = generate_fingerprint(content)
|
||||
if self.fingerprints.get(str(i)) != file_fingerprint:
|
||||
logger.error(f'file: "{i}" changed but not indexed')
|
||||
continue
|
||||
filter_filenames.add(str(i))
|
||||
nodes = await self._search(query=query, filters=filter_filenames)
|
||||
return result + nodes
|
||||
|
||||
async def merge(
|
||||
self, query: str, indices_list: List[List[Union[NodeWithScore, TextScore]]]
|
||||
) -> List[Union[NodeWithScore, TextScore]]:
|
||||
"""Merge results from multiple indices based on the query.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
indices_list (List[List[Union[NodeWithScore, TextScore]]]): A list of result lists from different indices.
|
||||
|
||||
Returns:
|
||||
List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity.
|
||||
"""
|
||||
if not self.embedding:
|
||||
config = Config.default()
|
||||
if self.model:
|
||||
config.embedding.model = self.model
|
||||
factory = RAGEmbeddingFactory(config)
|
||||
self.embedding = factory.get_rag_embedding()
|
||||
|
||||
scores = []
|
||||
query_embedding = await self.embedding.aget_text_embedding(query)
|
||||
flat_nodes = [node for indices in indices_list for node in indices]
|
||||
for i in flat_nodes:
|
||||
text_embedding = await self.embedding.aget_text_embedding(i.text)
|
||||
similarity = self.embedding.similarity(query_embedding, text_embedding)
|
||||
scores.append((similarity, i))
|
||||
scores.sort(key=lambda x: x[0], reverse=True)
|
||||
return [i[1] for i in scores][: self.recall_count]
|
||||
|
||||
async def add(self, paths: List[Path]):
|
||||
"""Add new documents to the index.
|
||||
|
||||
Args:
|
||||
paths (List[Path]): A list of paths to the documents to be added.
|
||||
"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
filenames, _ = await self._filter(paths)
|
||||
filter_filenames = []
|
||||
delete_filenames = []
|
||||
for i in filenames:
|
||||
content = await aread(filename=i)
|
||||
if not self._is_fingerprint_changed(filename=i, content=content):
|
||||
continue
|
||||
token_count = len(encoding.encode(content))
|
||||
if self._is_buildable(token_count):
|
||||
filter_filenames.append(i)
|
||||
logger.debug(f"{i} is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}")
|
||||
else:
|
||||
delete_filenames.append(i)
|
||||
logger.debug(f"{i} not is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}")
|
||||
await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames)
|
||||
|
||||
async def _add_batch(self, filenames: List[Union[str, Path]], delete_filenames: List[Union[str, Path]]):
|
||||
"""Add and remove documents in a batch operation.
|
||||
|
||||
Args:
|
||||
filenames (List[Union[str, Path]]): List of filenames to add.
|
||||
delete_filenames (List[Union[str, Path]]): List of filenames to delete.
|
||||
"""
|
||||
if not filenames:
|
||||
return
|
||||
logger.info(f"update index repo, add {filenames}, remove {delete_filenames}")
|
||||
engine = None
|
||||
if Path(self.persist_path).exists():
|
||||
logger.debug(f"load index from {self.persist_path}")
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=FAISSIndexConfig(persist_path=self.persist_path),
|
||||
retriever_configs=[FAISSRetrieverConfig()],
|
||||
)
|
||||
try:
|
||||
engine.delete_docs(filenames + delete_filenames)
|
||||
logger.debug(f"delete docs {filenames + delete_filenames}")
|
||||
engine.add_docs(input_files=filenames)
|
||||
logger.debug(f"add docs {filenames}")
|
||||
except NotImplementedError as e:
|
||||
logger.debug(f"{e}")
|
||||
filenames = list(set([str(i) for i in filenames] + list(self.fingerprints.keys())))
|
||||
engine = None
|
||||
logger.info(f"{e}. Rebuild all.")
|
||||
if not engine:
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_files=[str(i) for i in filenames],
|
||||
retriever_configs=[FAISSRetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
logger.debug(f"add docs {filenames}")
|
||||
engine.persist(persist_dir=self.persist_path)
|
||||
for i in filenames:
|
||||
content = await aread(i)
|
||||
fp = generate_fingerprint(content)
|
||||
self.fingerprints[str(i)] = fp
|
||||
await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints))
|
||||
|
||||
def __str__(self):
|
||||
"""Return a string representation of the IndexRepo.
|
||||
|
||||
Returns:
|
||||
str: The filename of the index repository.
|
||||
"""
|
||||
return f"{self.persist_path}"
|
||||
|
||||
def _is_buildable(self, token_count: int) -> bool:
|
||||
"""Check if the token count is within the buildable range.
|
||||
|
||||
Args:
|
||||
token_count (int): The number of tokens in the content.
|
||||
|
||||
Returns:
|
||||
bool: True if buildable, False otherwise.
|
||||
"""
|
||||
if token_count < self.min_token_count or token_count > self.max_token_count:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _filter(self, filenames: Optional[List[Union[str, Path]]] = None) -> (List[Path], List[Path]):
|
||||
"""Filter the provided filenames to only include valid text files.
|
||||
|
||||
Args:
|
||||
filenames (Optional[List[Union[str, Path]]]): List of filenames to filter.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Path], List[Path]]: A tuple containing a list of valid pathnames and a list of excluded paths.
|
||||
"""
|
||||
root_path = Path(self.root_path).absolute()
|
||||
if not filenames:
|
||||
filenames = [root_path]
|
||||
pathnames = []
|
||||
excludes = []
|
||||
for i in filenames:
|
||||
path = Path(i).absolute()
|
||||
if not path.is_relative_to(root_path):
|
||||
excludes.append(path)
|
||||
logger.debug(f"{path} not is_relative_to {root_path})")
|
||||
continue
|
||||
if not path.is_dir():
|
||||
is_text, _ = await is_text_file(path)
|
||||
if is_text:
|
||||
pathnames.append(path)
|
||||
continue
|
||||
subfiles = list_files(path)
|
||||
for j in subfiles:
|
||||
is_text, _ = await is_text_file(j)
|
||||
if is_text:
|
||||
pathnames.append(j)
|
||||
|
||||
logger.debug(f"{pathnames}, excludes:{excludes})")
|
||||
return pathnames, excludes
|
||||
|
||||
async def _search(self, query: str, filters: Set[str]) -> List[NodeWithScore]:
|
||||
"""Perform a search for the given query using the index.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
filters (Set[str]): A set of filenames to filter the search results.
|
||||
|
||||
Returns:
|
||||
List[NodeWithScore]: A list of nodes with scores matching the query.
|
||||
"""
|
||||
if not Path(self.persist_path).exists():
|
||||
return []
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()]
|
||||
)
|
||||
rsp = await engine.aretrieve(query)
|
||||
return [i for i in rsp if i.metadata.get("file_path") in filters]
|
||||
|
||||
def _is_fingerprint_changed(self, filename: Union[str, Path], content: str) -> bool:
|
||||
"""Check if the fingerprint of the given document content has changed.
|
||||
|
||||
Args:
|
||||
filename (Union[str, Path]): The filename of the document.
|
||||
content (str): The content of the document.
|
||||
|
||||
Returns:
|
||||
bool: True if the fingerprint has changed, False otherwise.
|
||||
"""
|
||||
old_fp = self.fingerprints.get(str(filename))
|
||||
if not old_fp:
|
||||
return True
|
||||
fp = generate_fingerprint(content)
|
||||
return old_fp != fp
|
||||
|
|
@ -16,6 +16,7 @@ import base64
|
|||
import contextlib
|
||||
import csv
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
|
|
@ -889,7 +890,7 @@ async def get_mime_type(filename: str | Path, force_read: bool = False) -> str:
|
|||
}
|
||||
|
||||
try:
|
||||
stdout, stderr, _ = await shell_execute(f"file --mime-type {str(filename)}")
|
||||
stdout, stderr, _ = await shell_execute(f"file --mime-type '{str(filename)}'")
|
||||
if stderr:
|
||||
logger.debug(f"file:{filename}, error:{stderr}")
|
||||
return guess_mime_type
|
||||
|
|
@ -1175,3 +1176,23 @@ def rectify_pathname(path: Union[str, Path], default_filename: str) -> Path:
|
|||
else:
|
||||
output_pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
return output_pathname
|
||||
|
||||
|
||||
def generate_fingerprint(text: str) -> str:
|
||||
"""
|
||||
Generate a fingerprint for the given text
|
||||
|
||||
Args:
|
||||
text (str): The text for which the fingerprint needs to be generated
|
||||
|
||||
Returns:
|
||||
str: The fingerprint value of the text
|
||||
"""
|
||||
text_bytes = text.encode("utf-8")
|
||||
|
||||
# calculate SHA-256 hash
|
||||
sha256 = hashlib.sha256()
|
||||
sha256.update(text_bytes)
|
||||
fingerprint = sha256.hexdigest()
|
||||
|
||||
return fingerprint
|
||||
|
|
|
|||
2
tests/data/embedding/2.answer.md
Normal file
2
tests/data/embedding/2.answer.md
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
检索结果
|
||||
法务查询者可根据国际小超人钉钉小程序UI上的滚筒切换业务线 这张图片展示了一个移动应用的界面,界面标题为“法律意见详情”。用户可以根据具体情况切换业务线。界面中有多个字段,包括“国家名称”、“国家情况描述”、“业务线”、“产品法规分析”和“签约主体”。第一张截图显示了详细的法律情报信息,包含区域名称、区域情况描述、业务线和产品法规概述等字段。第二张截图显示了“法律意见详情”界面,其中列出了国家名称、国家情况描述、业务线、产品法规分析和签约主体。第三张截图与第二张相似,但显示了选项的可选择状态。最下方有“取消”和“确定”的按钮。 法务查询者从国家详情中的业务线名列表中选出要查看的业务线。
|
||||
25
tests/data/embedding/2.knowledge.md
Normal file
25
tests/data/embedding/2.knowledge.md
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
## Textual User Requirements
|
||||
|
||||
### 3.2. 首页
|
||||
|
||||
首页有两个分区,上面部分是法律意见检索栏。
|
||||
|
||||
法务查询者第一次进入国际小超人钉钉小程序展示引导页,以后进入不再展示,点击「我知道了」引导页消失。
|
||||
|
||||
#### 首页
|
||||

|
||||
这是一个名为“法务小超人”的移动应用程序的界面截图。界面顶部显示了应用名称和一个可切换语言的按钮“English”。在界面中间部分,有一个标题“法律意见查询”,以及一个搜索框,提示输入国家名称以查询法律意见。下方显示已收录法律意见8394篇。界面下半部分是“法务 Q&A”部分,列出了一些法律相关的选项,例如“国际法务接入口人”、“国内法务接入口人”、“国际法律协议合同办理指引”和“国内法律协议合同办理指引”。界面底部有三个导航按钮,分别是“首页”、“模板”和“我的”。
|
||||
|
||||
#### 按国家名维度搜索
|
||||
法务查询者在国际小超人钉钉小程序的搜索框中进行检索时采用typeahead,只能下拉选择法务中台中有的国家名称。
|
||||

|
||||
在这张图像中,用户正在一个名为“法律意见查询”的应用中进行国家名称的搜索。用户在搜索框中输入国家名称时,系统会提供下拉建议。这些建议基于 typeahead 功能,从法务中台中筛选出匹配的国家名称供用户选择。目前,搜索结果包含了“中国”和“菲律宾”两个具体的国家名称,其它显示为“国家名”。用户可以通过下拉菜单快速选择所需的国家名称。
|
||||
|
||||
#### 检索结果
|
||||
法务查询者可根据国际小超人钉钉小程序UI上的滚筒切换业务线
|
||||

|
||||
这张图片展示了一个移动应用的界面,界面标题为“法律意见详情”。用户可以根据具体情况切换业务线。界面中有多个字段,包括“国家名称”、“国家情况描述”、“业务线”、“产品法规分析”和“签约主体”。第一张截图显示了详细的法律情报信息,包含区域名称、区域情况描述、业务线和产品法规概述等字段。第二张截图显示了“法律意见详情”界面,其中列出了国家名称、国家情况描述、业务线、产品法规分析和签约主体。第三张截图与第二张相似,但显示了选项的可选择状态。最下方有“取消”和“确定”的按钮。
|
||||
法务查询者从国家详情中的业务线名列表中选出要查看的业务线。
|
||||
|
||||
#### 查看法律意见详情
|
||||
国际小超人钉钉小程序用国家代码和业务代码做参数,查询法律意见详情,然后将法律意见详情展示给法务查询者。
|
||||
1
tests/data/embedding/2.query.md
Normal file
1
tests/data/embedding/2.query.md
Normal file
|
|
@ -0,0 +1 @@
|
|||
业务线UI有哪些操作?
|
||||
7
tests/data/embedding/3.answer.md
Normal file
7
tests/data/embedding/3.answer.md
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
国家/区域导游详情 & 法律意见详情 查询
|
||||
Description:根据国家code查询国家/区域导游信息详情
|
||||
ID: 8
|
||||
HTTP METHOD: GET
|
||||
Endpoint: /contract/country/navigate.json
|
||||
Input Parameters: |名称|描述|类型(长度)|必选|备注| | :- | :- | :-: | :- | :- | |countryCode|国家code|string|√||
|
||||
Returns: |名称|描述|类型(长度)|必选|备注| | :- | :- | :-: | :- | :- | |success|业务处理成功true,否则false|boolean|√|只判断这个属性即可| |message|错误信息,可以用来提示|string|√|| |code|返回状态码|string|√|| |data|国家/区域导游详情|object|√|| |-> country||||| |-> -> id|id|integer|√|| |-> -> country|国家code|string|√|| |-> -> countryName|国家中文名称|string|√|| |-> -> countryNameEn|国家英文名称|string|√|| |-> -> content|国家导游中文详情json数组,具体格式见下示例|list of object|√|| |-> -> -> title|标题|object|√|| |-> -> -> -> title|中文标题|string||| |-> -> -> -> titleEn|英文标题|string||| |-> -> -> contentList|标题下面的文字描述列表|list of object|√|| |-> -> -> -> detail|内容中文详情|string|√|| |-> -> -> -> detailEn|内容英文详情|string|√|| |-> -> -> -> url|超链接|string||| |-> legal|法务信息|object||| |-> -> country|国家code|string|√|| |-> -> businessList|业务线列表|list of object||| |-> -> -> id|id|integer||新增时不传,修改时传递| |-> -> -> business|业务线code|string|√|| |-> -> -> businessName|业务线中文名称|string|√|| |-> -> -> businessNameEn|业务线英文名称|string|√|| |-> -> -> content|业务线json,具体如下|object|√|| |-> -> -> -> detailEn|具体的详情英文内容|string|√|| |-> -> -> -> detail|具体的详情内容|string|√||
|
||||
189
tests/data/embedding/3.knowledge.md
Normal file
189
tests/data/embedding/3.knowledge.md
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
## Interfaces
|
||||
- 用户登录
|
||||
- Description: 用户从小程序/微应用发起请求,需要验证用户的合法身份才能正常处理。
|
||||
- ID: 1
|
||||
- HTTP METHOD: GET
|
||||
- Endpoint: `/sup/login.json`
|
||||
- Input Parameters:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|authCode|用户临时免登授权码|String(64)|√||
|
||||
|loginTypeEnum|登录类型|String(20)|√||
|
||||
|authCorpId|用户所在企业/组织id|String(64)||微应用免登时传递|
|
||||
|app|应用标识|String(3)|√||
|
||||
- Returns:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|success|业务处理成功与否,成功true,否则false|boolean|√|只判断这个属性即可|
|
||||
|message|错误信息,可以用来提示|string|√||
|
||||
|code|返回状态码|string|√||
|
||||
|data|用户的sessionId|string|√||
|
||||
- 根据sessionId查询用户详细信息
|
||||
- Description: 查询当前用户的详细信息,如 staffId,unionId,name,avatar等信息
|
||||
- ID: 2
|
||||
- HTTP METHOD: GET
|
||||
- Endpoint: `/sup/user.json`
|
||||
- Input Parameters:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|NDA_SESSION|用户sessionId|String(64)|√||
|
||||
- Returns:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|success|业务处理成功与否,成功true,否则false|boolean|√|只判断这个属性即可|
|
||||
|message|错误信息,可以用来提示|string|√||
|
||||
|code|返回状态码|string|√||
|
||||
|data|用户的详细信息|object|√||
|
||||
|-> corpId|当前用户企业 钉钉ID(小程序端会拿不到该信息)|string|√||
|
||||
|-> corpName|当前用户企业名称(小程序端会拿不到该信息)|string|√||
|
||||
|-> staffId|员工在当前企业内的唯一标识,也称staffId(小程序端会拿不到该信息)|string|√||
|
||||
|-> unionId|员工在当前开发者企业账号范围内的唯一标识,系统生成,固定值,不会改变。|string|√||
|
||||
|-> name|当前用户的名称(小程序端会拿不到该信息)|string|√||
|
||||
|-> avatar|头像图片URL|string|√||
|
||||
- 查询国家情况描述
|
||||
- Description: 根据国家code查询国家情况描述
|
||||
- ID: 3
|
||||
- HTTP METHOD: GET
|
||||
- Endpoint: `/sup/country/detail.json`
|
||||
- Input Parameters:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|countryCode|国家code|string|√||
|
||||
- Returns:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|success|业务处理成功true,否则false|boolean|√|只判断这个属性即可|
|
||||
|message|错误信息,可以用来提示|string|√||
|
||||
|code|返回状态码|string|√||
|
||||
|data|国家情况描述|object|√||
|
||||
|-> id|id|integer|√||
|
||||
|-> countryName|国家名称|string|√||
|
||||
|-> countryCode|国家code|string|√||
|
||||
|-> detail|产品法规分析|string|√||
|
||||
- 查询产品法规分析(法律意见详情)
|
||||
- Description: 根据国家和业务线查询产品法规分析
|
||||
- ID: 4
|
||||
- HTTP METHOD: GET
|
||||
- Endpoint: `/sup/legal/detail.json`
|
||||
- Input Parameters:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|countryCode|国家code|string|√||
|
||||
|businessCode|业务线code|string|√||
|
||||
- Returns:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|success|业务处理成功true,否则false|boolean|√|只判断这个属性即可|
|
||||
|message|错误信息,可以用来提示|string|√||
|
||||
|code|返回状态码|string|√||
|
||||
|data|法律意见详情|object|√||
|
||||
|-> id|id|integer|√||
|
||||
|-> countryName|国家名称|string|√||
|
||||
|-> countryCode|国家code|string|√||
|
||||
|-> businessLine|业务线|string|√||
|
||||
|-> businessCode|业务线code|string|√||
|
||||
|-> detail|产品法规分析|string|√||
|
||||
|-> signEntity|签约主体|string|√||
|
||||
- 查询法律意见总数
|
||||
- Description: 法律意见总数查询
|
||||
- ID: 5
|
||||
- HTTP METHOD: GET
|
||||
- Endpoint: `/sup/legal/count.json`
|
||||
- Input Parameters:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
- Returns:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|success|业务处理成功true,否则false|boolean|√|只判断这个属性即可|
|
||||
|message|错误信息,可以用来提示|string|√||
|
||||
|code|返回状态码|string|√||
|
||||
|data|总数|integer|√||
|
||||
- 查询所有国家和业务线信息列表
|
||||
- Description: 查询所有国家和业务线信息列表
|
||||
- ID: 6
|
||||
- HTTP METHOD: GET
|
||||
- Endpoint: `/sup/legal/country/list.json`
|
||||
- Input Parameters:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
- Returns:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|success|业务处理成功true,否则false|boolean|√|只判断这个属性即可|
|
||||
|message|错误信息,可以用来提示|string|√||
|
||||
|code|返回状态码|string|√||
|
||||
|data|所有数据列表|list of object|√||
|
||||
|-> country|国家code|string|√||
|
||||
|-> business|业务线code|string|√||
|
||||
|-> dataType|数据类型|string|√||
|
||||
|-> businessName|业务线名|string|√||
|
||||
|-> countryName|国家名|string|√||
|
||||
|-> businessNameEn|业务线名(英文)|string|√||
|
||||
- 调用法务中台antlaw接口
|
||||
- ID: 7
|
||||
- 国家/区域导游详情 & 法律意见详情 查询
|
||||
- Description:根据国家code查询国家/区域导游信息详情
|
||||
- ID: 8
|
||||
- HTTP METHOD: GET
|
||||
- Endpoint: `/contract/country/navigate.json`
|
||||
- Input Parameters:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|countryCode|国家code|string|√||
|
||||
- Returns:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|success|业务处理成功true,否则false|boolean|√|只判断这个属性即可|
|
||||
|message|错误信息,可以用来提示|string|√||
|
||||
|code|返回状态码|string|√||
|
||||
|data|国家/区域导游详情|object|√||
|
||||
|-> country|||||
|
||||
|-> -> id|id|integer|√||
|
||||
|-> -> country|国家code|string|√||
|
||||
|-> -> countryName|国家中文名称|string|√||
|
||||
|-> -> countryNameEn|国家英文名称|string|√||
|
||||
|-> -> content|国家导游中文详情json数组,具体格式见下示例|list of object|√||
|
||||
|-> -> -> title|标题|object|√||
|
||||
|-> -> -> -> title|中文标题|string|||
|
||||
|-> -> -> -> titleEn|英文标题|string|||
|
||||
|-> -> -> contentList|标题下面的文字描述列表|list of object|√||
|
||||
|-> -> -> -> detail|内容中文详情|string|√||
|
||||
|-> -> -> -> detailEn|内容英文详情|string|√||
|
||||
|-> -> -> -> url|超链接|string|||
|
||||
|-> legal|法务信息|object|||
|
||||
|-> -> country|国家code|string|√||
|
||||
|-> -> businessList|业务线列表|list of object|||
|
||||
|-> -> -> id|id|integer||新增时不传,修改时传递|
|
||||
|-> -> -> business|业务线code|string|√||
|
||||
|-> -> -> businessName|业务线中文名称|string|√||
|
||||
|-> -> -> businessNameEn|业务线英文名称|string|√||
|
||||
|-> -> -> content|业务线json,具体如下|object|√||
|
||||
|-> -> -> -> detailEn|具体的详情英文内容|string|√||
|
||||
|-> -> -> -> detail|具体的详情内容|string|√||
|
||||
- 国家/区域导游列表分页查询
|
||||
- Description: 分页查询国家/区域列表
|
||||
- ID: 9
|
||||
- HTTP METHOD: GET
|
||||
- Endpoint: `/contract/country/list.json`
|
||||
- Input Parameters:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|pageSize|分页大小|integer|√|>=1|
|
||||
|pageNum|分页大小|integer|√|>=1|
|
||||
|country|国家code|string|||
|
||||
|business|业务线code|string|||
|
||||
- Returns:
|
||||
|名称|描述|类型(长度)|必选|备注|
|
||||
| :- | :- | :-: | :- | :- |
|
||||
|success|业务处理成功true,否则false|boolean|√|只判断这个属性即可|
|
||||
|message|错误信息,可以用来提示|string|√||
|
||||
|code|返回状态码|string|√||
|
||||
|data|国家/区域导游详情|list of object|√||
|
||||
|-> id|id|integer|√||
|
||||
|-> country|国家code|string|√||
|
||||
|-> countryName|国家中文名称|string|√||
|
||||
|-> countryNameEn|国家英文名称|string|√||
|
||||
|-> gmtCreate|创建时间|string|√||
|
||||
|-> gmtModified|更新时间|string|√||
|
||||
|total|数据总量|integer|√||
|
||||
1
tests/data/embedding/3.query.md
Normal file
1
tests/data/embedding/3.query.md
Normal file
|
|
@ -0,0 +1 @@
|
|||
根据国家code查询国家业务线列表
|
||||
0
tests/metagpt/rag/__init__.py
Normal file
0
tests/metagpt/rag/__init__.py
Normal file
55
tests/metagpt/rag/test_large_pdf.py
Normal file
55
tests/metagpt/rag/test_large_pdf.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.parametrize(
|
||||
("knowledge_filename", "query_filename", "answer_filename"),
|
||||
[
|
||||
(
|
||||
TEST_DATA_PATH / "embedding/2.knowledge.md",
|
||||
TEST_DATA_PATH / "embedding/2.query.md",
|
||||
TEST_DATA_PATH / "embedding/2.answer.md",
|
||||
),
|
||||
(
|
||||
TEST_DATA_PATH / "embedding/3.knowledge.md",
|
||||
TEST_DATA_PATH / "embedding/3.query.md",
|
||||
TEST_DATA_PATH / "embedding/3.answer.md",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_pdf(knowledge_filename, query_filename, answer_filename):
|
||||
Config.default(reload=True) # `config.embedding.model = "text-embedding-ada-002"` changes the cache.
|
||||
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_files=[knowledge_filename],
|
||||
)
|
||||
|
||||
query = await aread(filename=query_filename)
|
||||
rsp = await engine.aretrieve(query)
|
||||
assert rsp
|
||||
|
||||
config = Config.default()
|
||||
config.embedding.model = "text-embedding-ada-002"
|
||||
factory = RAGEmbeddingFactory(config)
|
||||
embedding = factory.get_rag_embedding()
|
||||
answer = await aread(filename=answer_filename)
|
||||
answer_embedding = await embedding.aget_text_embedding(answer)
|
||||
similarity = 0
|
||||
for i in rsp:
|
||||
rsp_embedding = await embedding.aget_query_embedding(i.text)
|
||||
v = embedding.similarity(answer_embedding, rsp_embedding)
|
||||
similarity = max(similarity, v)
|
||||
|
||||
print(similarity)
|
||||
assert similarity > 0.9
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
32
tests/metagpt/tools/libs/test_index_repo.py
Normal file
32
tests/metagpt/tools/libs/test_index_repo.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
|
||||
from metagpt.tools.libs.index_repo import IndexRepo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")])
|
||||
async def test_index_repo(path, query):
|
||||
index_path = DEFAULT_WORKSPACE_ROOT / ".index"
|
||||
repo = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0)
|
||||
await repo.add([path])
|
||||
await repo.add([path])
|
||||
assert index_path.exists()
|
||||
|
||||
rsp = await repo.search(query)
|
||||
assert rsp
|
||||
|
||||
repo2 = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0)
|
||||
rsp2 = await repo2.search(query)
|
||||
assert rsp2
|
||||
|
||||
merged_rsp = await repo.merge(query=query, indices_list=[rsp, rsp2])
|
||||
assert merged_rsp
|
||||
|
||||
shutil.rmtree(index_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue