Merge branch 'feature/rfc258' into 'mgx_ops'

Feat: RFC-258-Editor.search设计方案对应的Index Repo

See merge request pub/MetaGPT!362
This commit is contained in:
林义章 2024-09-04 11:36:29 +00:00
commit ac29811c2f
18 changed files with 648 additions and 20 deletions

View file

@ -2,7 +2,6 @@ from abc import abstractmethod
from typing import Optional, Union
from metagpt.base.base_serialization import BaseSerialization
from metagpt.schema import Message
class BaseRole(BaseSerialization):
@ -25,13 +24,13 @@ class BaseRole(BaseSerialization):
raise NotImplementedError
@abstractmethod
async def react(self) -> Message:
async def react(self) -> "Message":
"""Entry to one of three strategies by which Role reacts to the observed Message."""
@abstractmethod
async def run(self, with_message: Optional[Union[str, Message, list[str]]] = None) -> Optional[Message]:
async def run(self, with_message: Optional[Union[str, "Message", list[str]]] = None) -> Optional["Message"]:
"""Observe, and think and act based on the results of the observation."""
@abstractmethod
def get_memories(self, k: int = 0) -> list[Message]:
def get_memories(self, k: int = 0) -> list["Message"]:
"""Return the most recent k memories of this role."""

View file

@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel):
---------
api_type: "openai"
api_key: "YOU_API_KEY"
dimensions: "YOUR_MODEL_DIMENSIONS"
api_type: "azure"
api_key: "YOU_API_KEY"
base_url: "YOU_BASE_URL"
api_version: "YOU_API_VERSION"
dimensions: "YOUR_MODEL_DIMENSIONS"
api_type: "gemini"
api_key: "YOU_API_KEY"
@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel):
api_type: "ollama"
base_url: "YOU_BASE_URL"
model: "YOU_MODEL"
dimensions: "YOUR_MODEL_DIMENSIONS"
"""
api_type: Optional[EmbeddingType] = None
@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel):
model: Optional[str] = None
embed_batch_size: Optional[int] = None
dimensions: Optional[int] = None # output dimension of embedding model
@field_validator("api_type", mode="before")
@classmethod

View file

@ -27,7 +27,6 @@ from metagpt.configs.llm_config import LLMConfig
from metagpt.const import IMAGES, LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.constant import MULTI_MODAL_MODELS
from metagpt.schema import Message
from metagpt.utils.common import log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.token_counter import TOKEN_MAX
@ -80,7 +79,7 @@ class BaseLLM(ABC):
def support_image_input(self) -> bool:
return any([m in self.config.model for m in MULTI_MODAL_MODELS])
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]:
"""convert messages to list[dict]."""
from metagpt.schema import Message
@ -173,7 +172,9 @@ class BaseLLM(ABC):
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs) -> dict:
async def aask_code(
self, messages: Union[str, "Message", list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs
) -> dict:
raise NotImplementedError
@abstractmethod

View file

@ -22,7 +22,6 @@ from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
class GeminiGenerativeModel(GenerativeModel):
@ -73,7 +72,7 @@ class GeminiLLM(BaseLLM):
def _system_msg(self, msg: str) -> dict[str, str]:
return {"role": "user", "parts": [msg]}
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
def format_msg(self, messages: Union[str, "Message", list[dict], list["Message"], list[str]]) -> list[dict]:
"""convert messages to list[dict]."""
from metagpt.schema import Message

View file

@ -2,7 +2,8 @@
import json
import os
from typing import Any, Optional, Union
from pathlib import Path
from typing import Any, List, Optional, Set, Union
import fsspec
from llama_index.core import SimpleDirectoryReader
@ -78,6 +79,7 @@ class SimpleEngine(RetrieverQueryEngine):
callback_manager=callback_manager,
)
self._transformations = transformations or self._default_transformations()
self._filenames = set()
@classmethod
def from_docs(
@ -192,11 +194,11 @@ class SimpleEngine(RetrieverQueryEngine):
self._try_reconstruct_obj(nodes)
return nodes
def add_docs(self, input_files: list[str]):
def add_docs(self, input_files: List[Union[str, Path]]):
"""Add docs to retriever. retriever must has add_nodes func."""
self._ensure_retriever_modifiable()
documents = SimpleDirectoryReader(input_files=input_files).load_data()
documents = SimpleDirectoryReader(input_files=[str(i) for i in input_files]).load_data()
self._fix_document_metadata(documents)
nodes = run_transformations(documents, transformations=self._transformations)
@ -227,6 +229,24 @@ class SimpleEngine(RetrieverQueryEngine):
return self.retriever.clear(**kwargs)
def delete_docs(self, input_files: List[Union[str, Path]]):
"""Delete documents from the index and document store.
Args:
input_files (List[Union[str, Path]]): A list of file paths or file names to be deleted.
Raises:
NotImplementedError: If the method is not implemented.
"""
exists_filenames = set()
filenames = {str(i) for i in input_files}
for doc_id, info in self.retriever._index.ref_doc_info.items():
if info.metadata.get("file_path") in filenames:
exists_filenames.add(doc_id)
for doc_id in exists_filenames:
self.retriever._index.delete_ref_doc(doc_id, delete_from_docstore=True)
@staticmethod
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
"""Converts a list of RAGObjects to a list of ObjectNodes."""
@ -333,3 +353,7 @@ class SimpleEngine(RetrieverQueryEngine):
@staticmethod
def _default_transformations():
return [SentenceSplitter()]
@property
def filenames(self) -> Set[str]:
return self._filenames

View file

@ -5,9 +5,6 @@ from typing import Any, Optional
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from metagpt.config2 import Config
from metagpt.configs.embedding_config import EmbeddingType
@ -49,7 +46,9 @@ class RAGEmbeddingFactory(GenericFactory):
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
def _create_openai(self) -> OpenAIEmbedding:
def _create_openai(self) -> "OpenAIEmbedding":
from llama_index.embeddings.openai import OpenAIEmbedding
params = dict(
api_key=self.config.embedding.api_key or self.config.llm.api_key,
api_base=self.config.embedding.base_url or self.config.llm.base_url,
@ -70,7 +69,9 @@ class RAGEmbeddingFactory(GenericFactory):
return AzureOpenAIEmbedding(**params)
def _create_gemini(self) -> GeminiEmbedding:
def _create_gemini(self) -> "GeminiEmbedding":
from llama_index.embeddings.gemini import GeminiEmbedding
params = dict(
api_key=self.config.embedding.api_key,
api_base=self.config.embedding.base_url,
@ -80,7 +81,9 @@ class RAGEmbeddingFactory(GenericFactory):
return GeminiEmbedding(**params)
def _create_ollama(self) -> OllamaEmbedding:
def _create_ollama(self) -> "OllamaEmbedding":
from llama_index.embeddings.ollama import OllamaEmbedding
params = dict(
base_url=self.config.embedding.base_url,
)

View file

@ -13,7 +13,6 @@ from llama_index.core.llms.callbacks import llm_completion_callback
from pydantic import Field
from metagpt.config2 import Config
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.token_counter import TOKEN_MAX
@ -79,4 +78,6 @@ class RAGLLM(CustomLLM):
def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM:
"""Get llm that can be used by LlamaIndex."""
from metagpt.llm import LLM
return RAGLLM(model_infer=model_infer or LLM())

View file

@ -0,0 +1,264 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
from pathlib import Path
from typing import Dict, List, Optional, Set, Union
import tiktoken
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel, Field, model_validator
from metagpt.config2 import Config
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig
from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files
from metagpt.utils.repo_to_markdown import is_text_file
class TextScore(BaseModel):
filename: str
text: str
score: Optional[float] = None
class IndexRepo(BaseModel):
persist_path: str # The persist path of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/
root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo.
fingerprint_filename: str = "fingerprint.json"
model: Optional[str] = None
min_token_count: int = 10000
max_token_count: int = 100000000
recall_count: int = 5
embedding: Optional[BaseEmbedding] = Field(default=None, exclude=True)
fingerprints: Dict[str, str] = Field(default_factory=dict)
@model_validator(mode="after")
def _update_fingerprints(self) -> "IndexRepo":
"""Load fingerprints from the fingerprint file if not already loaded.
Returns:
IndexRepo: The updated IndexRepo instance.
"""
if not self.fingerprints:
filename = Path(self.persist_path) / self.fingerprint_filename
if not filename.exists():
return self
with open(str(filename), "r") as reader:
self.fingerprints = json.load(reader)
return self
async def search(
self, query: str, filenames: Optional[List[Path]] = None
) -> Optional[List[Union[NodeWithScore, TextScore]]]:
"""Search for documents related to the given query.
Args:
query (str): The search query.
filenames (Optional[List[Path]]): A list of filenames to filter the search.
Returns:
Optional[List[Union[NodeWithScore, TextScore]]]: A list of search results containing NodeWithScore or TextScore.
"""
encoding = tiktoken.get_encoding("cl100k_base")
result: List[Union[NodeWithScore, TextScore]] = []
filenames, _ = await self._filter(filenames)
filter_filenames = set()
for i in filenames:
content = await aread(filename=i)
token_count = len(encoding.encode(content))
if not self._is_buildable(token_count):
result.append(TextScore(filename=str(i), text=content))
continue
file_fingerprint = generate_fingerprint(content)
if self.fingerprints.get(str(i)) != file_fingerprint:
logger.error(f'file: "{i}" changed but not indexed')
continue
filter_filenames.add(str(i))
nodes = await self._search(query=query, filters=filter_filenames)
return result + nodes
async def merge(
self, query: str, indices_list: List[List[Union[NodeWithScore, TextScore]]]
) -> List[Union[NodeWithScore, TextScore]]:
"""Merge results from multiple indices based on the query.
Args:
query (str): The search query.
indices_list (List[List[Union[NodeWithScore, TextScore]]]): A list of result lists from different indices.
Returns:
List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity.
"""
if not self.embedding:
config = Config.default()
if self.model:
config.embedding.model = self.model
factory = RAGEmbeddingFactory(config)
self.embedding = factory.get_rag_embedding()
scores = []
query_embedding = await self.embedding.aget_text_embedding(query)
flat_nodes = [node for indices in indices_list for node in indices]
for i in flat_nodes:
text_embedding = await self.embedding.aget_text_embedding(i.text)
similarity = self.embedding.similarity(query_embedding, text_embedding)
scores.append((similarity, i))
scores.sort(key=lambda x: x[0], reverse=True)
return [i[1] for i in scores][: self.recall_count]
async def add(self, paths: List[Path]):
"""Add new documents to the index.
Args:
paths (List[Path]): A list of paths to the documents to be added.
"""
encoding = tiktoken.get_encoding("cl100k_base")
filenames, _ = await self._filter(paths)
filter_filenames = []
delete_filenames = []
for i in filenames:
content = await aread(filename=i)
if not self._is_fingerprint_changed(filename=i, content=content):
continue
token_count = len(encoding.encode(content))
if self._is_buildable(token_count):
filter_filenames.append(i)
logger.debug(f"{i} is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}")
else:
delete_filenames.append(i)
logger.debug(f"{i} not is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}")
await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames)
async def _add_batch(self, filenames: List[Union[str, Path]], delete_filenames: List[Union[str, Path]]):
"""Add and remove documents in a batch operation.
Args:
filenames (List[Union[str, Path]]): List of filenames to add.
delete_filenames (List[Union[str, Path]]): List of filenames to delete.
"""
if not filenames:
return
logger.info(f"update index repo, add {filenames}, remove {delete_filenames}")
engine = None
if Path(self.persist_path).exists():
logger.debug(f"load index from {self.persist_path}")
engine = SimpleEngine.from_index(
index_config=FAISSIndexConfig(persist_path=self.persist_path),
retriever_configs=[FAISSRetrieverConfig()],
)
try:
engine.delete_docs(filenames + delete_filenames)
logger.debug(f"delete docs {filenames + delete_filenames}")
engine.add_docs(input_files=filenames)
logger.debug(f"add docs {filenames}")
except NotImplementedError as e:
logger.debug(f"{e}")
filenames = list(set([str(i) for i in filenames] + list(self.fingerprints.keys())))
engine = None
logger.info(f"{e}. Rebuild all.")
if not engine:
engine = SimpleEngine.from_docs(
input_files=[str(i) for i in filenames],
retriever_configs=[FAISSRetrieverConfig()],
ranker_configs=[LLMRankerConfig()],
)
logger.debug(f"add docs {filenames}")
engine.persist(persist_dir=self.persist_path)
for i in filenames:
content = await aread(i)
fp = generate_fingerprint(content)
self.fingerprints[str(i)] = fp
await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints))
def __str__(self):
"""Return a string representation of the IndexRepo.
Returns:
str: The filename of the index repository.
"""
return f"{self.persist_path}"
def _is_buildable(self, token_count: int) -> bool:
"""Check if the token count is within the buildable range.
Args:
token_count (int): The number of tokens in the content.
Returns:
bool: True if buildable, False otherwise.
"""
if token_count < self.min_token_count or token_count > self.max_token_count:
return False
return True
async def _filter(self, filenames: Optional[List[Union[str, Path]]] = None) -> (List[Path], List[Path]):
"""Filter the provided filenames to only include valid text files.
Args:
filenames (Optional[List[Union[str, Path]]]): List of filenames to filter.
Returns:
Tuple[List[Path], List[Path]]: A tuple containing a list of valid pathnames and a list of excluded paths.
"""
root_path = Path(self.root_path).absolute()
if not filenames:
filenames = [root_path]
pathnames = []
excludes = []
for i in filenames:
path = Path(i).absolute()
if not path.is_relative_to(root_path):
excludes.append(path)
logger.debug(f"{path} not is_relative_to {root_path})")
continue
if not path.is_dir():
is_text, _ = await is_text_file(path)
if is_text:
pathnames.append(path)
continue
subfiles = list_files(path)
for j in subfiles:
is_text, _ = await is_text_file(j)
if is_text:
pathnames.append(j)
logger.debug(f"{pathnames}, excludes:{excludes})")
return pathnames, excludes
async def _search(self, query: str, filters: Set[str]) -> List[NodeWithScore]:
"""Perform a search for the given query using the index.
Args:
query (str): The search query.
filters (Set[str]): A set of filenames to filter the search results.
Returns:
List[NodeWithScore]: A list of nodes with scores matching the query.
"""
if not Path(self.persist_path).exists():
return []
engine = SimpleEngine.from_index(
index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()]
)
rsp = await engine.aretrieve(query)
return [i for i in rsp if i.metadata.get("file_path") in filters]
def _is_fingerprint_changed(self, filename: Union[str, Path], content: str) -> bool:
"""Check if the fingerprint of the given document content has changed.
Args:
filename (Union[str, Path]): The filename of the document.
content (str): The content of the document.
Returns:
bool: True if the fingerprint has changed, False otherwise.
"""
old_fp = self.fingerprints.get(str(filename))
if not old_fp:
return True
fp = generate_fingerprint(content)
return old_fp != fp

View file

@ -16,6 +16,7 @@ import base64
import contextlib
import csv
import functools
import hashlib
import importlib
import inspect
import json
@ -889,7 +890,7 @@ async def get_mime_type(filename: str | Path, force_read: bool = False) -> str:
}
try:
stdout, stderr, _ = await shell_execute(f"file --mime-type {str(filename)}")
stdout, stderr, _ = await shell_execute(f"file --mime-type '{str(filename)}'")
if stderr:
logger.debug(f"file:{filename}, error:{stderr}")
return guess_mime_type
@ -1175,3 +1176,23 @@ def rectify_pathname(path: Union[str, Path], default_filename: str) -> Path:
else:
output_pathname.parent.mkdir(parents=True, exist_ok=True)
return output_pathname
def generate_fingerprint(text: str) -> str:
"""
Generate a fingerprint for the given text
Args:
text (str): The text for which the fingerprint needs to be generated
Returns:
str: The fingerprint value of the text
"""
text_bytes = text.encode("utf-8")
# calculate SHA-256 hash
sha256 = hashlib.sha256()
sha256.update(text_bytes)
fingerprint = sha256.hexdigest()
return fingerprint