cr修改,单测完善

This commit is contained in:
liuminhui 2024-07-22 17:10:29 +08:00
parent 79334de5a4
commit 758acf8ba6
18 changed files with 371 additions and 178 deletions

View file

@ -12,7 +12,8 @@ from typing import Dict, Iterable, List, Literal, Optional
from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.embedding_config import EmbeddingConfig, OmniParseConfig
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.configs.file_parser_config import OmniParseConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig

View file

@ -52,8 +52,3 @@ class EmbeddingConfig(YamlModel):
if v == "":
return None
return v
class OmniParseConfig(YamlModel):
api_key: str = ""
base_url: str = ""

View file

@ -0,0 +1,6 @@
from metagpt.utils.yaml_model import YamlModel
class OmniParseConfig(YamlModel):
api_key: str = ""
base_url: str = ""

View file

@ -14,6 +14,7 @@ from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.readers.base import BaseReader
from llama_index.core.response_synthesizers import (
BaseSynthesizer,
get_response_synthesizer,
@ -27,7 +28,6 @@ from llama_index.core.schema import (
QueryType,
TransformComponent,
)
from llama_parse import ResultType
from metagpt.config2 import config
from metagpt.rag.factories import (
@ -38,7 +38,7 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.parser.omniparse.parse import OmniParse
from metagpt.rag.parser import OmniParse
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
@ -46,7 +46,10 @@ from metagpt.rag.schema import (
BaseRankerConfig,
BaseRetrieverConfig,
BM25RetrieverConfig,
ObjectNode, OmniParseOptions, OmniParseType,
ObjectNode,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.common import import_class
@ -76,18 +79,6 @@ class SimpleEngine(RetrieverQueryEngine):
)
self._transformations = transformations or self._default_transformations()
@classmethod
def get_file_extractor(cls, file_type: str):
if not config.omniparse.base_url:
return
parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ResultType.MD)
)
file_extractor = {file_type: parser}
return file_extractor
@classmethod
def from_docs(
cls,
@ -115,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine):
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
file_extractor = cls.get_file_extractor(file_type=".pdf")
file_extractor = cls._get_file_extractor(file_type=".pdf")
documents = SimpleDirectoryReader(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
).load_data()
@ -319,3 +310,31 @@ class SimpleEngine(RetrieverQueryEngine):
@staticmethod
def _default_transformations():
return [SentenceSplitter()]
@staticmethod
def _get_file_extractor(file_type: str = None) -> dict[str:BaseReader]:
"""
Get the file extractor for a specified file type.
If no file type is provided, return all available extractors.
Currently, only OmniParse PDF extraction is supported.
Args:
file_type: The type of file for which the extractor is needed. Defaults to None.
Returns:
dict[file_type: BaseReader]
"""
file_extractor_mapping: dict[str:BaseReader] = {}
if config.omniparse.base_url:
pdf_parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
)
file_extractor_mapping[".pdf"] = pdf_parser
if file_type:
file_extractor = file_extractor_mapping.get(file_type)
return {file_type: file_extractor} if file_extractor else {}
return file_extractor_mapping

View file

@ -0,0 +1,3 @@
from metagpt.rag.parser.omniparse import OmniParse
__all__ = ["OmniParse"]

View file

@ -1,28 +1,31 @@
import asyncio
from fileinput import FileInput
from pathlib import Path
from typing import List, Union, Optional
from typing import List, Optional, Union
from llama_index.core import Document
from llama_index.core.async_utils import run_jobs
from llama_index.core.readers.base import BaseReader
from llama_parse import ResultType
from metagpt.rag.parser.omniparse.client import OmniParseClient
from metagpt.rag.schema import OmniParseOptions, OmniParseType
from metagpt.logs import logger
from metagpt.rag.schema import OmniParseOptions, OmniParseType
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.omniparse_client import OmniParseClient
class OmniParse(BaseReader):
"""OmniParse"""
def __init__(
self,
api_key=None,
base_url="http://localhost:8000",
parse_options: OmniParseOptions = None
self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None
):
"""
Args:
api_key: Default None, can be used for authentication later.
base_url: OmniParse Base URL for the API.
parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values.
"""
self.parse_options = parse_options or OmniParseOptions()
self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout)
@ -47,20 +50,32 @@ class OmniParse(BaseReader):
self.parse_options.result_type = result_type
async def _aload_data(
self,
file_path: Union[str, bytes, Path],
extra_info: Optional[dict] = None,
self,
file_path: Union[str, bytes, Path],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Returns:
List[Document]
"""
try:
if self.parse_type == OmniParseType.PDF:
# 目前先只支持pdf解析
# pdf parse
parsed_result = await self.omniparse_client.parse_pdf(file_path)
else:
# other parse use omniparse_client.parse_document
# For compatible byte data, additional filename is required
extra_info = extra_info or {}
filename = extra_info.get("filename") # 兼容字节数据要额外传filename
filename = extra_info.get("filename")
parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename)
# 获取指定的结构数据
# Get the specified structured data based on result_type
content = getattr(parsed_result, self.result_type)
docs = [
Document(
@ -75,26 +90,51 @@ class OmniParse(BaseReader):
return docs
async def aload_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls _aload_data for processing.
Returns:
List[Document]
"""
docs = []
if isinstance(file_path, (str, bytes, Path)):
# 处理单个
# Processing single file
docs = await self._aload_data(file_path, extra_info)
elif isinstance(file_path, list):
# 并发处理多个
# Concurrently process multiple files
parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path]
doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers)
docs = [doc for docs in doc_ret_list for doc in docs]
return docs
def load_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""Load data from the input path."""
NestAsyncio.apply_once() # 兼容异步嵌套调用
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls aload_data for processing.
Returns:
List[Document]
"""
NestAsyncio.apply_once() # Ensure compatibility with nested async calls
return asyncio.run(self.aload_data(file_path, extra_info))

View file

@ -1,2 +0,0 @@
from .client import OmniParseClient
from .parse import OmniParse

View file

@ -8,7 +8,6 @@ from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from llama_parse import ResultType
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from metagpt.config2 import config
@ -218,23 +217,31 @@ class ObjectNode(TextNode):
class OmniParseType(str, Enum):
"""OmniParse解析类型"""
"""OmniParseType"""
PDF = "PDF"
DOCUMENT = "DOCUMENT"
class ParseResultType(str, Enum):
"""The result type for the parser."""
TXT = "text"
MD = "markdown"
JSON = "json"
class OmniParseOptions(BaseModel):
"""OmniParse可选配置"""
result_type: ResultType = Field(default=ResultType.MD, description="OmniParse解析返回的结果类型")
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型默认文档类型")
max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时")
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")
max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests")
num_workers: int = Field(
default=5,
gt=0,
lt=10,
description="多文件列表时并发请求数量",
description="Number of concurrent requests for multiple files",
)

View file

@ -3,28 +3,35 @@ import os
from pathlib import Path
from typing import Union
import aiofiles
import httpx
from metagpt.rag.schema import OmniParsedResult
from metagpt.utils.common import aread_bin
class OmniParseClient:
"""
OmniParse Server Client
OmniParse API Docs: https://docs.cognitivelab.in/api
This client interacts with the OmniParse server to parse different types of media, documents, and websites.
OmniParse API Documentation: https://docs.cognitivelab.in/api
Attributes:
ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
"""
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
def __init__(self, api_key=None, base_url="http://localhost:8000", max_timeout=120):
def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
"""
Args:
api_key: 默认 None 后续可用于鉴权
base_url: api 基础url
max_timeout: 请求最大超时时间单位s
api_key: Default None, can be used for authentication later.
base_url: Base URL for the API.
max_timeout: Maximum request timeout in seconds.
"""
self.api_key = api_key
self.base_url = base_url
@ -46,19 +53,20 @@ class OmniParseClient:
**kwargs,
) -> dict:
"""
请求api解析文档
Request OmniParse API to parse a document.
Args:
endpoint (str): API endpoint.
files (dict, optional): 请求文件数据.
params (dict, optional): 查询字符串参数.
data (dict, optional): 请求体数据.
json (dict, optional): 请求json数据.
headers (dict, optional): 请求头数据.
**kwargs: 其他 httpx.AsyncClient.request() 关键字参数
method (str, optional): HTTP method to use. Default is "POST".
files (dict, optional): Files to include in the request.
params (dict, optional): Query string parameters.
data (dict, optional): Form data to include in the request body.
json (dict, optional): JSON data to include in the request body.
headers (dict, optional): HTTP headers to include in the request.
**kwargs: Additional keyword arguments for httpx.AsyncClient.request()
Returns:
dict: 响应的json数据
dict: JSON response data.
"""
url = f"{self.base_url}{endpoint}"
method = method.upper()
@ -80,17 +88,94 @@ class OmniParseClient:
response.raise_for_status()
return response.json()
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
Args:
filelike: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the document parsing.
"""
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename)
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
"""
Parse pdf document.
Args:
filelike: File path or file byte data.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the pdf parsing.
"""
self.verify_file_ext(filelike, {".pdf"})
file_info = await self.get_file_info(filelike, only_bytes=True)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
Args:
filelike: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse audio-type data (supports ".mp3", ".wav", ".aac").
Args:
filelike: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
@staticmethod
def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
"""
校验文件后缀
Verify the file extension.
Args:
filelike: 文件路径 or 文件字节数据
allowed_file_extensions: 允许的文件扩展名
bytes_filename: 当字节数据时使用这个参数校验
filelike: File path or file byte data.
allowed_file_extensions: Set of allowed file extensions.
bytes_filename: Filename to use for verification when `filelike` is byte data.
Raises:
ValueError
ValueError: If the file extension is not allowed.
Returns:
"""
@ -101,7 +186,7 @@ class OmniParseClient:
verify_file_path = bytes_filename
if not verify_file_path:
# 仅传字节数据时不校验
# Do not verify if only byte data is provided
return
file_ext = os.path.splitext(verify_file_path)[1].lower()
@ -112,28 +197,29 @@ class OmniParseClient:
async def get_file_info(
filelike: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes=True,
only_bytes: bool = False,
) -> Union[bytes, tuple]:
"""
获取文件字节信息
Get file information.
Args:
filelike: 文件数据
bytes_filename: 通过字节数据上传需要指定文件名称方便获取mime_type
only_bytes: 是否只需要字节数据
filelike: File path or file byte data.
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
Raises:
ValueError
ValueError: If bytes_filename is not provided when filelike is bytes or if filelike is not a valid type.
Notes:
由于 parse_document 支持多种文件解析需要上传文件时指定文件的mime_type
Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types,
the MIME type of the file must be specified when uploading.
Returns:
[bytes, tuple]
Returns: [bytes, tuple]
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
"""
if isinstance(filelike, (str, Path)):
filename = os.path.basename(str(filelike))
async with aiofiles.open(filelike, "rb") as file:
file_bytes = await file.read()
file_bytes = await aread_bin(filelike)
if only_bytes:
return file_bytes
@ -150,60 +236,3 @@ class OmniParseClient:
return bytes_filename, filelike, mime_type
else:
raise ValueError("filelike must be a string (file path) or bytes.")
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
解析文档类型数据支持 ".pdf", ".ppt", ".pptx", ".doc", ".docs"
Args:
filelike: 文件路径 or 文件字节数据
bytes_filename: 字节数据名称方便获取mime_type 用于httpx请求
Raises
ValueError
Returns:
OmniParsedResult
"""
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
"""
解析PDF文档
Args:
filelike: 文件路径 or 文件字节数据
Raises
ValueError
Returns:
OmniParsedResult
"""
self.verify_file_ext(filelike, {".pdf"})
file_info = await self.get_file_info(filelike)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""解析视频"""
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""解析音频"""
self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
async def parse_website(self, url: str) -> dict:
"""
解析网站
fixme:官方api还存在问题
"""
return await self._request_parse(f"{self.parse_website_endpoint}/parse", params={"url": url})