Merge pull request #1408 from seehi/feat-omniparse

Feat omniparse
This commit is contained in:
better629 2024-08-06 11:45:32 +08:00 committed by GitHub
commit 22e1009128
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 684 additions and 5 deletions

1
.gitattributes vendored
View file

@ -14,6 +14,7 @@
*.ico binary
*.jpeg binary
*.mp3 binary
*.mp4 binary
*.zip binary
*.bin binary

View file

@ -60,6 +60,10 @@ iflytek_api_secret: "YOUR_API_SECRET"
metagpt_tti_url: "YOUR_MODEL_URL"
omniparse:
api_key: "YOUR_API_KEY"
base_url: "YOUR_BASE_URL"
models:
# "YOUR_MODEL_NAME_1 or YOUR_API_TYPE_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo
# api_type: "openai" # or azure / ollama / groq etc.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

64
examples/rag/omniparse.py Normal file
View file

@ -0,0 +1,64 @@
import asyncio
from metagpt.config2 import config
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.logs import logger
from metagpt.rag.parsers import OmniParse
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
from metagpt.utils.omniparse_client import OmniParseClient
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
async def omniparse_client_example():
client = OmniParseClient(base_url=config.omniparse.base_url)
# docx
with open(TEST_DOCX, "rb") as f:
file_input = f.read()
document_parse_ret = await client.parse_document(file_input=file_input, bytes_filename="test_01.docx")
logger.info(document_parse_ret)
# pdf
pdf_parse_ret = await client.parse_pdf(file_input=TEST_PDF)
logger.info(pdf_parse_ret)
# video
video_parse_ret = await client.parse_video(file_input=TEST_VIDEO)
logger.info(video_parse_ret)
# audio
audio_parse_ret = await client.parse_audio(file_input=TEST_AUDIO)
logger.info(audio_parse_ret)
async def omniparse_example():
parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(
parse_type=OmniParseType.PDF,
result_type=ParseResultType.MD,
max_timeout=120,
num_workers=3,
),
)
ret = parser.load_data(file_path=TEST_PDF)
logger.info(ret)
file_paths = [TEST_DOCX, TEST_PDF]
parser.parse_type = OmniParseType.DOCUMENT
ret = await parser.aload_data(file_path=file_paths)
logger.info(ret)
async def main():
await omniparse_client_example()
await omniparse_example()
if __name__ == "__main__":
asyncio.run(main())

View file

@ -2,7 +2,7 @@
import asyncio
from examples.rag_pipeline import DOC_PATH, QUESTION
from examples.rag.rag_pipeline import DOC_PATH, QUESTION
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.roles import Sales

View file

@ -13,6 +13,7 @@ from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.configs.file_parser_config import OmniParseConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig
@ -51,6 +52,9 @@ class Config(CLIParams, YamlModel):
# RAG Embedding
embedding: EmbeddingConfig = EmbeddingConfig()
# omniparse
omniparse: OmniParseConfig = OmniParseConfig()
# Global Proxy. Will be used if llm.proxy is not set
proxy: str = ""

View file

@ -0,0 +1,6 @@
from metagpt.utils.yaml_model import YamlModel
class OmniParseConfig(YamlModel):
api_key: str = ""
base_url: str = ""

View file

@ -14,6 +14,7 @@ from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.readers.base import BaseReader
from llama_index.core.response_synthesizers import (
BaseSynthesizer,
get_response_synthesizer,
@ -28,6 +29,7 @@ from llama_index.core.schema import (
TransformComponent,
)
from metagpt.config2 import config
from metagpt.rag.factories import (
get_index,
get_rag_embedding,
@ -36,6 +38,7 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
@ -44,6 +47,9 @@ from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ObjectNode,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.common import import_class
@ -100,7 +106,10 @@ class SimpleEngine(RetrieverQueryEngine):
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
file_extractor = cls._get_file_extractor()
documents = SimpleDirectoryReader(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
).load_data()
cls._fix_document_metadata(documents)
transformations = transformations or cls._default_transformations()
@ -301,3 +310,23 @@ class SimpleEngine(RetrieverQueryEngine):
@staticmethod
def _default_transformations():
return [SentenceSplitter()]
@staticmethod
def _get_file_extractor() -> dict[str:BaseReader]:
"""
Get the file extractor.
Currently, only PDF use OmniParse. Other document types use the built-in reader from llama_index.
Returns:
dict[file_type: BaseReader]
"""
file_extractor: dict[str:BaseReader] = {}
if config.omniparse.base_url:
pdf_parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
)
file_extractor[".pdf"] = pdf_parser
return file_extractor

View file

@ -0,0 +1,3 @@
from metagpt.rag.parsers.omniparse import OmniParse
__all__ = ["OmniParse"]

View file

@ -0,0 +1,139 @@
import asyncio
from fileinput import FileInput
from pathlib import Path
from typing import List, Optional, Union
from llama_index.core import Document
from llama_index.core.async_utils import run_jobs
from llama_index.core.readers.base import BaseReader
from metagpt.logs import logger
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.omniparse_client import OmniParseClient
class OmniParse(BaseReader):
"""OmniParse"""
def __init__(
self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None
):
"""
Args:
api_key: Default None, can be used for authentication later.
base_url: OmniParse Base URL for the API.
parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values.
"""
self.parse_options = parse_options or OmniParseOptions()
self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout)
@property
def parse_type(self):
return self.parse_options.parse_type
@property
def result_type(self):
return self.parse_options.result_type
@parse_type.setter
def parse_type(self, parse_type: Union[str, OmniParseType]):
if isinstance(parse_type, str):
parse_type = OmniParseType(parse_type)
self.parse_options.parse_type = parse_type
@result_type.setter
def result_type(self, result_type: Union[str, ParseResultType]):
if isinstance(result_type, str):
result_type = ParseResultType(result_type)
self.parse_options.result_type = result_type
async def _aload_data(
self,
file_path: Union[str, bytes, Path],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Returns:
List[Document]
"""
try:
if self.parse_type == OmniParseType.PDF:
# pdf parse
parsed_result = await self.omniparse_client.parse_pdf(file_path)
else:
# other parse use omniparse_client.parse_document
# For compatible byte data, additional filename is required
extra_info = extra_info or {}
filename = extra_info.get("filename")
parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename)
# Get the specified structured data based on result_type
content = getattr(parsed_result, self.result_type)
docs = [
Document(
text=content,
metadata=extra_info or {},
)
]
except Exception as e:
logger.error(f"OMNI Parse Error: {e}")
docs = []
return docs
async def aload_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls _aload_data for processing.
Returns:
List[Document]
"""
docs = []
if isinstance(file_path, (str, bytes, Path)):
# Processing single file
docs = await self._aload_data(file_path, extra_info)
elif isinstance(file_path, list):
# Concurrently process multiple files
parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path]
doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers)
docs = [doc for docs in doc_ret_list for doc in docs]
return docs
def load_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls aload_data for processing.
Returns:
List[Document]
"""
NestAsyncio.apply_once() # Ensure compatibility with nested async calls
return asyncio.run(self.aload_data(file_path, extra_info))

View file

@ -1,7 +1,7 @@
"""RAG schemas."""
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, Literal, Optional, Union
from typing import Any, ClassVar, List, Literal, Optional, Union
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
@ -214,3 +214,51 @@ class ObjectNode(TextNode):
)
return metadata.model_dump()
class OmniParseType(str, Enum):
"""OmniParseType"""
PDF = "PDF"
DOCUMENT = "DOCUMENT"
class ParseResultType(str, Enum):
"""The result type for the parser."""
TXT = "text"
MD = "markdown"
JSON = "json"
class OmniParseOptions(BaseModel):
"""OmniParse Options config"""
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")
max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests")
num_workers: int = Field(
default=5,
gt=0,
lt=10,
description="Number of concurrent requests for multiple files",
)
class OminParseImage(BaseModel):
image: str = Field(default="", description="image str bytes")
image_name: str = Field(default="", description="image name")
image_info: Optional[dict] = Field(default={}, description="image info")
class OmniParsedResult(BaseModel):
markdown: str = Field(default="", description="markdown text")
text: str = Field(default="", description="plain text")
images: Optional[List[OminParseImage]] = Field(default=[], description="images")
metadata: Optional[dict] = Field(default={}, description="metadata")
@model_validator(mode="before")
def set_markdown(cls, values):
if not values.get("markdown"):
values["markdown"] = values.get("text")
return values

View file

@ -0,0 +1,239 @@
import mimetypes
import os
from pathlib import Path
from typing import Union
import httpx
from metagpt.rag.schema import OmniParsedResult
from metagpt.utils.common import aread_bin
class OmniParseClient:
"""
OmniParse Server Client
This client interacts with the OmniParse server to parse different types of media, documents.
OmniParse API Documentation: https://docs.cognitivelab.in/api
Attributes:
ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
"""
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
"""
Args:
api_key: Default None, can be used for authentication later.
base_url: Base URL for the API.
max_timeout: Maximum request timeout in seconds.
"""
self.api_key = api_key
self.base_url = base_url
self.max_timeout = max_timeout
self.parse_media_endpoint = "/parse_media"
self.parse_website_endpoint = "/parse_website"
self.parse_document_endpoint = "/parse_document"
async def _request_parse(
self,
endpoint: str,
method: str = "POST",
files: dict = None,
params: dict = None,
data: dict = None,
json: dict = None,
headers: dict = None,
**kwargs,
) -> dict:
"""
Request OmniParse API to parse a document.
Args:
endpoint (str): API endpoint.
method (str, optional): HTTP method to use. Default is "POST".
files (dict, optional): Files to include in the request.
params (dict, optional): Query string parameters.
data (dict, optional): Form data to include in the request body.
json (dict, optional): JSON data to include in the request body.
headers (dict, optional): HTTP headers to include in the request.
**kwargs: Additional keyword arguments for httpx.AsyncClient.request()
Returns:
dict: JSON response data.
"""
url = f"{self.base_url}{endpoint}"
method = method.upper()
headers = headers or {}
_headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
headers.update(**_headers)
async with httpx.AsyncClient() as client:
response = await client.request(
url=url,
method=method,
files=files,
params=params,
json=json,
data=data,
headers=headers,
timeout=self.max_timeout,
**kwargs,
)
response.raise_for_status()
return response.json()
async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the document parsing.
"""
self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult:
"""
Parse pdf document.
Args:
file_input: File path or file byte data.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the pdf parsing.
"""
self.verify_file_ext(file_input, {".pdf"})
# parse_pdf supports parsing by accepting only the byte data of the file.
file_info = await self.get_file_info(file_input, only_bytes=True)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse audio-type data (supports ".mp3", ".wav", ".aac").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
@staticmethod
def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
"""
Verify the file extension.
Args:
file_input: File path or file byte data.
allowed_file_extensions: Set of allowed file extensions.
bytes_filename: Filename to use for verification when `file_input` is byte data.
Raises:
ValueError: If the file extension is not allowed.
Returns:
"""
verify_file_path = None
if isinstance(file_input, (str, Path)):
verify_file_path = str(file_input)
elif isinstance(file_input, bytes) and bytes_filename:
verify_file_path = bytes_filename
if not verify_file_path:
# Do not verify if only byte data is provided
return
file_ext = os.path.splitext(verify_file_path)[1].lower()
if file_ext not in allowed_file_extensions:
raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")
@staticmethod
async def get_file_info(
file_input: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes: bool = False,
) -> Union[bytes, tuple]:
"""
Get file information.
Args:
file_input: File path or file byte data.
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
Raises:
ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type.
Notes:
Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types,
the MIME type of the file must be specified when uploading.
Returns: [bytes, tuple]
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
"""
if isinstance(file_input, (str, Path)):
filename = os.path.basename(str(file_input))
file_bytes = await aread_bin(file_input)
if only_bytes:
return file_bytes
mime_type = mimetypes.guess_type(file_input)[0]
return filename, file_bytes, mime_type
elif isinstance(file_input, bytes):
if only_bytes:
return file_input
if not bytes_filename:
raise ValueError("bytes_filename must be set when passing bytes")
mime_type = mimetypes.guess_type(bytes_filename)[0]
return bytes_filename, file_input, mime_type
else:
raise ValueError("file_input must be a string (file path) or bytes.")

View file

@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM
from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers import SimpleHybridRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
@ -37,6 +38,10 @@ class TestSimpleEngine:
def mock_get_response_synthesizer(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
@pytest.fixture
def mock_get_file_extractor(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor")
def test_from_docs(
self,
mocker,
@ -44,6 +49,7 @@ class TestSimpleEngine:
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
mock_get_file_extractor,
):
# Mock
mock_simple_directory_reader.return_value.load_data.return_value = [
@ -53,6 +59,8 @@ class TestSimpleEngine:
mock_get_retriever.return_value = mocker.MagicMock()
mock_get_rankers.return_value = [mocker.MagicMock()]
mock_get_response_synthesizer.return_value = mocker.MagicMock()
file_extractor = mocker.MagicMock()
mock_get_file_extractor.return_value = file_extractor
# Setup
input_dir = "test_dir"
@ -75,7 +83,9 @@ class TestSimpleEngine:
)
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_simple_directory_reader.assert_called_once_with(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
@ -298,3 +308,17 @@ class TestSimpleEngine:
# Assert
assert "obj" in node.node.metadata
assert node.node.metadata["obj"] == expected_obj
def test_get_file_extractor(self, mocker):
# mock no omniparse config
mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)
mock_omniparse_config.base_url = ""
file_extractor = SimpleEngine._get_file_extractor()
assert file_extractor == {}
# mock have omniparse config
mock_omniparse_config.base_url = "http://localhost:8000"
file_extractor = SimpleEngine._get_file_extractor()
assert ".pdf" in file_extractor
assert isinstance(file_extractor[".pdf"], OmniParse)

View file

@ -0,0 +1,118 @@
import pytest
from llama_index.core import Document
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.parsers import OmniParse
from metagpt.rag.schema import (
OmniParsedResult,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.omniparse_client import OmniParseClient
# test data
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
class TestOmniParseClient:
parse_client = OmniParseClient()
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_parse_pdf(self, mock_request_parse):
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_document(self, mock_request_parse):
mock_content = "#test title\ntest_parse_document"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
with open(TEST_DOCX, "rb") as f:
file_bytes = f.read()
with pytest.raises(ValueError):
# bytes data must provide bytes_filename
await self.parse_client.parse_document(file_bytes)
parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_video(self, mock_request_parse):
mock_content = "#test title\ntest_parse_video"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
with pytest.raises(ValueError):
# Wrong file extension test
await self.parse_client.parse_video(TEST_DOCX)
parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
@pytest.mark.asyncio
async def test_parse_audio(self, mock_request_parse):
mock_content = "#test title\ntest_parse_audio"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
class TestOmniParse:
@pytest.fixture
def mock_omniparse(self):
parser = OmniParse(
parse_options=OmniParseOptions(
parse_type=OmniParseType.PDF,
result_type=ParseResultType.MD,
max_timeout=120,
num_workers=3,
)
)
return parser
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_load_data(self, mock_omniparse, mock_request_parse):
# mock
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
# single file
documents = mock_omniparse.load_data(file_path=TEST_PDF)
doc = documents[0]
assert isinstance(doc, Document)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
# multi files
file_paths = [TEST_DOCX, TEST_PDF]
mock_omniparse.parse_type = OmniParseType.DOCUMENT
documents = await mock_omniparse.aload_data(file_path=file_paths)
doc = documents[0]
# assert
assert isinstance(doc, Document)
assert len(documents) == len(file_paths)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown