cr修改,单测完善

This commit is contained in:
liuminhui 2024-07-22 17:10:29 +08:00
parent 79334de5a4
commit 758acf8ba6
18 changed files with 371 additions and 178 deletions

View file

@ -59,3 +59,7 @@ iflytek_api_key: "YOUR_API_KEY"
iflytek_api_secret: "YOUR_API_SECRET"
metagpt_tti_url: "YOUR_MODEL_URL"
omniparse:
api_key: "YOUR_API_KEY"
base_url: "YOUR_BASE_URL"

View file

@ -5,8 +5,4 @@ llm:
api_type: "openai" # or azure / ollama / groq etc.
model: "gpt-4-turbo" # or gpt-3.5-turbo
base_url: "https://api.openai.com/v1" # or forward url / other llm url
api_key: "xxxx"
omniparse:
api_key: "your_api_key"
base_url: "http://192.168.50.126:8000"
api_key: "YOUR_API_KEY"

View file

@ -1,18 +1,16 @@
import asyncio
from llama_parse import ResultType
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.rag.parser.omniparse.client import OmniParseClient
from metagpt.rag.parser.omniparse.parse import OmniParse
from metagpt.rag.schema import OmniParseOptions, OmniParseType
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.logs import logger
from metagpt.rag.parser import OmniParse
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
from metagpt.utils.omniparse_client import OmniParseClient
TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3"
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
TEST_WEBSITE_URL = "https://github.com/geekan/MetaGPT"
@ -37,10 +35,6 @@ async def omniparse_client_example():
audio_parse_ret = await client.parse_audio(filelike=TEST_AUDIO)
logger.info(audio_parse_ret)
# website fixme:omniparse官方api还存在问题
# website_parse_ret = await client.parse_website(url=TEST_WEBSITE_URL)
# logger.info(website_parse_ret)
async def omniparse_example():
parser = OmniParse(
@ -48,10 +42,10 @@ async def omniparse_example():
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(
parse_type=OmniParseType.PDF,
result_type=ResultType.MD,
result_type=ParseResultType.MD,
max_timeout=120,
num_workers=3,
)
),
)
ret = parser.load_data(file_path=TEST_PDF)
logger.info(ret)
@ -67,5 +61,5 @@ async def main():
await omniparse_example()
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main())

View file

@ -12,7 +12,8 @@ from typing import Dict, Iterable, List, Literal, Optional
from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.embedding_config import EmbeddingConfig, OmniParseConfig
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.configs.file_parser_config import OmniParseConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig

View file

@ -52,8 +52,3 @@ class EmbeddingConfig(YamlModel):
if v == "":
return None
return v
class OmniParseConfig(YamlModel):
api_key: str = ""
base_url: str = ""

View file

@ -0,0 +1,6 @@
from metagpt.utils.yaml_model import YamlModel
class OmniParseConfig(YamlModel):
api_key: str = ""
base_url: str = ""

View file

@ -14,6 +14,7 @@ from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.readers.base import BaseReader
from llama_index.core.response_synthesizers import (
BaseSynthesizer,
get_response_synthesizer,
@ -27,7 +28,6 @@ from llama_index.core.schema import (
QueryType,
TransformComponent,
)
from llama_parse import ResultType
from metagpt.config2 import config
from metagpt.rag.factories import (
@ -38,7 +38,7 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.parser.omniparse.parse import OmniParse
from metagpt.rag.parser import OmniParse
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
@ -46,7 +46,10 @@ from metagpt.rag.schema import (
BaseRankerConfig,
BaseRetrieverConfig,
BM25RetrieverConfig,
ObjectNode, OmniParseOptions, OmniParseType,
ObjectNode,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.common import import_class
@ -76,18 +79,6 @@ class SimpleEngine(RetrieverQueryEngine):
)
self._transformations = transformations or self._default_transformations()
@classmethod
def get_file_extractor(cls, file_type: str):
if not config.omniparse.base_url:
return
parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ResultType.MD)
)
file_extractor = {file_type: parser}
return file_extractor
@classmethod
def from_docs(
cls,
@ -115,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine):
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
file_extractor = cls.get_file_extractor(file_type=".pdf")
file_extractor = cls._get_file_extractor(file_type=".pdf")
documents = SimpleDirectoryReader(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
).load_data()
@ -319,3 +310,31 @@ class SimpleEngine(RetrieverQueryEngine):
@staticmethod
def _default_transformations():
return [SentenceSplitter()]
@staticmethod
def _get_file_extractor(file_type: str = None) -> dict[str:BaseReader]:
"""
Get the file extractor for a specified file type.
If no file type is provided, return all available extractors.
Currently, only OmniParse PDF extraction is supported.
Args:
file_type: The type of file for which the extractor is needed. Defaults to None.
Returns:
dict[file_type: BaseReader]
"""
file_extractor_mapping: dict[str:BaseReader] = {}
if config.omniparse.base_url:
pdf_parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
)
file_extractor_mapping[".pdf"] = pdf_parser
if file_type:
file_extractor = file_extractor_mapping.get(file_type)
return {file_type: file_extractor} if file_extractor else {}
return file_extractor_mapping

View file

@ -0,0 +1,3 @@
from metagpt.rag.parser.omniparse import OmniParse
__all__ = ["OmniParse"]

View file

@ -1,28 +1,31 @@
import asyncio
from fileinput import FileInput
from pathlib import Path
from typing import List, Union, Optional
from typing import List, Optional, Union
from llama_index.core import Document
from llama_index.core.async_utils import run_jobs
from llama_index.core.readers.base import BaseReader
from llama_parse import ResultType
from metagpt.rag.parser.omniparse.client import OmniParseClient
from metagpt.rag.schema import OmniParseOptions, OmniParseType
from metagpt.logs import logger
from metagpt.rag.schema import OmniParseOptions, OmniParseType
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.omniparse_client import OmniParseClient
class OmniParse(BaseReader):
"""OmniParse"""
def __init__(
self,
api_key=None,
base_url="http://localhost:8000",
parse_options: OmniParseOptions = None
self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None
):
"""
Args:
api_key: Default None, can be used for authentication later.
base_url: OmniParse Base URL for the API.
parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values.
"""
self.parse_options = parse_options or OmniParseOptions()
self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout)
@ -47,20 +50,32 @@ class OmniParse(BaseReader):
self.parse_options.result_type = result_type
async def _aload_data(
self,
file_path: Union[str, bytes, Path],
extra_info: Optional[dict] = None,
self,
file_path: Union[str, bytes, Path],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Returns:
List[Document]
"""
try:
if self.parse_type == OmniParseType.PDF:
# 目前先只支持pdf解析
# pdf parse
parsed_result = await self.omniparse_client.parse_pdf(file_path)
else:
# other parse use omniparse_client.parse_document
# For compatible byte data, additional filename is required
extra_info = extra_info or {}
filename = extra_info.get("filename") # 兼容字节数据要额外传filename
filename = extra_info.get("filename")
parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename)
# 获取指定的结构数据
# Get the specified structured data based on result_type
content = getattr(parsed_result, self.result_type)
docs = [
Document(
@ -75,26 +90,51 @@ class OmniParse(BaseReader):
return docs
async def aload_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls _aload_data for processing.
Returns:
List[Document]
"""
docs = []
if isinstance(file_path, (str, bytes, Path)):
# 处理单个
# Processing single file
docs = await self._aload_data(file_path, extra_info)
elif isinstance(file_path, list):
# 并发处理多个
# Concurrently process multiple files
parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path]
doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers)
docs = [doc for docs in doc_ret_list for doc in docs]
return docs
def load_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""Load data from the input path."""
NestAsyncio.apply_once() # 兼容异步嵌套调用
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls aload_data for processing.
Returns:
List[Document]
"""
NestAsyncio.apply_once() # Ensure compatibility with nested async calls
return asyncio.run(self.aload_data(file_path, extra_info))

View file

@ -1,2 +0,0 @@
from .client import OmniParseClient
from .parse import OmniParse

View file

@ -8,7 +8,6 @@ from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from llama_parse import ResultType
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from metagpt.config2 import config
@ -218,23 +217,31 @@ class ObjectNode(TextNode):
class OmniParseType(str, Enum):
"""OmniParse解析类型"""
"""OmniParseType"""
PDF = "PDF"
DOCUMENT = "DOCUMENT"
class ParseResultType(str, Enum):
"""The result type for the parser."""
TXT = "text"
MD = "markdown"
JSON = "json"
class OmniParseOptions(BaseModel):
"""OmniParse可选配置"""
result_type: ResultType = Field(default=ResultType.MD, description="OmniParse解析返回的结果类型")
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型默认文档类型")
max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时")
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")
max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests")
num_workers: int = Field(
default=5,
gt=0,
lt=10,
description="多文件列表时并发请求数量",
description="Number of concurrent requests for multiple files",
)

View file

@ -3,28 +3,35 @@ import os
from pathlib import Path
from typing import Union
import aiofiles
import httpx
from metagpt.rag.schema import OmniParsedResult
from metagpt.utils.common import aread_bin
class OmniParseClient:
"""
OmniParse Server Client
OmniParse API Docs: https://docs.cognitivelab.in/api
This client interacts with the OmniParse server to parse different types of media, documents, and websites.
OmniParse API Documentation: https://docs.cognitivelab.in/api
Attributes:
ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
"""
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
def __init__(self, api_key=None, base_url="http://localhost:8000", max_timeout=120):
def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
"""
Args:
api_key: 默认 None 后续可用于鉴权
base_url: api 基础url
max_timeout: 请求最大超时时间单位s
api_key: Default None, can be used for authentication later.
base_url: Base URL for the API.
max_timeout: Maximum request timeout in seconds.
"""
self.api_key = api_key
self.base_url = base_url
@ -46,19 +53,20 @@ class OmniParseClient:
**kwargs,
) -> dict:
"""
请求api解析文档
Request OmniParse API to parse a document.
Args:
endpoint (str): API endpoint.
files (dict, optional): 请求文件数据.
params (dict, optional): 查询字符串参数.
data (dict, optional): 请求体数据.
json (dict, optional): 请求json数据.
headers (dict, optional): 请求头数据.
**kwargs: 其他 httpx.AsyncClient.request() 关键字参数
method (str, optional): HTTP method to use. Default is "POST".
files (dict, optional): Files to include in the request.
params (dict, optional): Query string parameters.
data (dict, optional): Form data to include in the request body.
json (dict, optional): JSON data to include in the request body.
headers (dict, optional): HTTP headers to include in the request.
**kwargs: Additional keyword arguments for httpx.AsyncClient.request()
Returns:
dict: 响应的json数据
dict: JSON response data.
"""
url = f"{self.base_url}{endpoint}"
method = method.upper()
@ -80,17 +88,94 @@ class OmniParseClient:
response.raise_for_status()
return response.json()
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
Args:
filelike: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the document parsing.
"""
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename)
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
"""
Parse pdf document.
Args:
filelike: File path or file byte data.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the pdf parsing.
"""
self.verify_file_ext(filelike, {".pdf"})
file_info = await self.get_file_info(filelike, only_bytes=True)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
Args:
filelike: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse audio-type data (supports ".mp3", ".wav", ".aac").
Args:
filelike: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
@staticmethod
def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
"""
校验文件后缀
Verify the file extension.
Args:
filelike: 文件路径 or 文件字节数据
allowed_file_extensions: 允许的文件扩展名
bytes_filename: 当字节数据时使用这个参数校验
filelike: File path or file byte data.
allowed_file_extensions: Set of allowed file extensions.
bytes_filename: Filename to use for verification when `filelike` is byte data.
Raises:
ValueError
ValueError: If the file extension is not allowed.
Returns:
"""
@ -101,7 +186,7 @@ class OmniParseClient:
verify_file_path = bytes_filename
if not verify_file_path:
# 仅传字节数据时不校验
# Do not verify if only byte data is provided
return
file_ext = os.path.splitext(verify_file_path)[1].lower()
@ -112,28 +197,29 @@ class OmniParseClient:
async def get_file_info(
filelike: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes=True,
only_bytes: bool = False,
) -> Union[bytes, tuple]:
"""
获取文件字节信息
Get file information.
Args:
filelike: 文件数据
bytes_filename: 通过字节数据上传需要指定文件名称方便获取mime_type
only_bytes: 是否只需要字节数据
filelike: File path or file byte data.
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
Raises:
ValueError
ValueError: If bytes_filename is not provided when filelike is bytes or if filelike is not a valid type.
Notes:
由于 parse_document 支持多种文件解析需要上传文件时指定文件的mime_type
Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types,
the MIME type of the file must be specified when uploading.
Returns:
[bytes, tuple]
Returns: [bytes, tuple]
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
"""
if isinstance(filelike, (str, Path)):
filename = os.path.basename(str(filelike))
async with aiofiles.open(filelike, "rb") as file:
file_bytes = await file.read()
file_bytes = await aread_bin(filelike)
if only_bytes:
return file_bytes
@ -150,60 +236,3 @@ class OmniParseClient:
return bytes_filename, filelike, mime_type
else:
raise ValueError("filelike must be a string (file path) or bytes.")
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
解析文档类型数据支持 ".pdf", ".ppt", ".pptx", ".doc", ".docs"
Args:
filelike: 文件路径 or 文件字节数据
bytes_filename: 字节数据名称方便获取mime_type 用于httpx请求
Raises
ValueError
Returns:
OmniParsedResult
"""
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
"""
解析PDF文档
Args:
filelike: 文件路径 or 文件字节数据
Raises
ValueError
Returns:
OmniParsedResult
"""
self.verify_file_ext(filelike, {".pdf"})
file_info = await self.get_file_info(filelike)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""解析视频"""
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""解析音频"""
self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
async def parse_website(self, url: str) -> dict:
"""
解析网站
fixme:官方api还存在问题
"""
return await self._request_parse(f"{self.parse_website_endpoint}/parse", params={"url": url})

View file

@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM
from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.parser import OmniParse
from metagpt.rag.retrievers import SimpleHybridRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
@ -39,7 +40,7 @@ class TestSimpleEngine:
@pytest.fixture
def mock_get_file_extractor(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine.get_file_extractor")
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor")
def test_from_docs(
self,
@ -307,3 +308,17 @@ class TestSimpleEngine:
# Assert
assert "obj" in node.node.metadata
assert node.node.metadata["obj"] == expected_obj
def test_get_file_extractor(self, mocker):
# mock no omniparse config
mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)
mock_omniparse_config.base_url = ""
file_extractor = SimpleEngine._get_file_extractor()
assert file_extractor == {}
# mock have omniparse config
mock_omniparse_config.base_url = "http://localhost:8000"
file_extractor = SimpleEngine._get_file_extractor(file_type=".pdf")
assert ".pdf" in file_extractor
assert isinstance(file_extractor[".pdf"], OmniParse)

View file

@ -1,32 +1,118 @@
import pytest
from llama_index.core import Document
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.parser.omniparse import OmniParseClient
from metagpt.rag.schema import OmniParsedResult
from metagpt.rag.parser import OmniParse
from metagpt.rag.schema import (
OmniParsedResult,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.omniparse_client import OmniParseClient
# test data
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
class TestOmniParseClient:
parse_client = OmniParseClient()
# test data
TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3"
@pytest.fixture
def request_parse(self, mocker):
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_parse_pdf(self, request_parse):
async def test_parse_pdf(self, mock_request_parse):
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
request_parse.return_value = mock_parsed_ret.model_dump()
parse_ret = await self.parse_client.parse_pdf(self.TEST_PDF)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_document(self, mock_request_parse):
mock_content = "#test title\ntest_parse_document"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
with open(TEST_DOCX, "rb") as f:
file_bytes = f.read()
with pytest.raises(ValueError):
# bytes data must provide bytes_filename
await self.parse_client.parse_document(file_bytes)
parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_video(self, mock_request_parse):
mock_content = "#test title\ntest_parse_video"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
with pytest.raises(ValueError):
# Wrong file extension test
await self.parse_client.parse_video(TEST_DOCX)
parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
@pytest.mark.asyncio
async def test_parse_audio(self, mock_request_parse):
mock_content = "#test title\ntest_parse_audio"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
class TestOmniParse:
def test_load_data(self):
pass
@pytest.fixture
def mock_omniparse(self):
parser = OmniParse(
parse_options=OmniParseOptions(
parse_type=OmniParseType.PDF,
result_type=ParseResultType.MD,
max_timeout=120,
num_workers=3,
)
)
return parser
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_load_data(self, mock_omniparse, mock_request_parse):
# mock
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
# single file
documents = mock_omniparse.load_data(file_path=TEST_PDF)
doc = documents[0]
assert isinstance(doc, Document)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
# multi files
file_paths = [TEST_DOCX, TEST_PDF]
mock_omniparse.parse_type = OmniParseType.DOCUMENT
documents = await mock_omniparse.aload_data(file_path=file_paths)
doc = documents[0]
# assert
assert isinstance(doc, Document)
assert len(documents) == len(file_paths)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown