mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
commit
22e1009128
19 changed files with 684 additions and 5 deletions
1
.gitattributes
vendored
1
.gitattributes
vendored
|
|
@ -14,6 +14,7 @@
|
|||
*.ico binary
|
||||
*.jpeg binary
|
||||
*.mp3 binary
|
||||
*.mp4 binary
|
||||
*.zip binary
|
||||
*.bin binary
|
||||
|
||||
|
|
|
|||
|
|
@ -60,6 +60,10 @@ iflytek_api_secret: "YOUR_API_SECRET"
|
|||
|
||||
metagpt_tti_url: "YOUR_MODEL_URL"
|
||||
|
||||
omniparse:
|
||||
api_key: "YOUR_API_KEY"
|
||||
base_url: "YOUR_BASE_URL"
|
||||
|
||||
models:
|
||||
# "YOUR_MODEL_NAME_1 or YOUR_API_TYPE_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo
|
||||
# api_type: "openai" # or azure / ollama / groq etc.
|
||||
|
|
|
|||
BIN
examples/data/omniparse/test01.docx
Normal file
BIN
examples/data/omniparse/test01.docx
Normal file
Binary file not shown.
BIN
examples/data/omniparse/test02.pdf
Normal file
BIN
examples/data/omniparse/test02.pdf
Normal file
Binary file not shown.
BIN
examples/data/omniparse/test03.mp4
Normal file
BIN
examples/data/omniparse/test03.mp4
Normal file
Binary file not shown.
BIN
examples/data/omniparse/test04.mp3
Normal file
BIN
examples/data/omniparse/test04.mp3
Normal file
Binary file not shown.
64
examples/rag/omniparse.py
Normal file
64
examples/rag/omniparse.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
import asyncio
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
|
||||
from metagpt.utils.omniparse_client import OmniParseClient
|
||||
|
||||
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
|
||||
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
|
||||
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
|
||||
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
|
||||
|
||||
|
||||
async def omniparse_client_example():
|
||||
client = OmniParseClient(base_url=config.omniparse.base_url)
|
||||
|
||||
# docx
|
||||
with open(TEST_DOCX, "rb") as f:
|
||||
file_input = f.read()
|
||||
document_parse_ret = await client.parse_document(file_input=file_input, bytes_filename="test_01.docx")
|
||||
logger.info(document_parse_ret)
|
||||
|
||||
# pdf
|
||||
pdf_parse_ret = await client.parse_pdf(file_input=TEST_PDF)
|
||||
logger.info(pdf_parse_ret)
|
||||
|
||||
# video
|
||||
video_parse_ret = await client.parse_video(file_input=TEST_VIDEO)
|
||||
logger.info(video_parse_ret)
|
||||
|
||||
# audio
|
||||
audio_parse_ret = await client.parse_audio(file_input=TEST_AUDIO)
|
||||
logger.info(audio_parse_ret)
|
||||
|
||||
|
||||
async def omniparse_example():
|
||||
parser = OmniParse(
|
||||
api_key=config.omniparse.api_key,
|
||||
base_url=config.omniparse.base_url,
|
||||
parse_options=OmniParseOptions(
|
||||
parse_type=OmniParseType.PDF,
|
||||
result_type=ParseResultType.MD,
|
||||
max_timeout=120,
|
||||
num_workers=3,
|
||||
),
|
||||
)
|
||||
ret = parser.load_data(file_path=TEST_PDF)
|
||||
logger.info(ret)
|
||||
|
||||
file_paths = [TEST_DOCX, TEST_PDF]
|
||||
parser.parse_type = OmniParseType.DOCUMENT
|
||||
ret = await parser.aload_data(file_path=file_paths)
|
||||
logger.info(ret)
|
||||
|
||||
|
||||
async def main():
|
||||
await omniparse_client_example()
|
||||
await omniparse_example()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from examples.rag_pipeline import DOC_PATH, QUESTION
|
||||
from examples.rag.rag_pipeline import DOC_PATH, QUESTION
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.roles import Sales
|
||||
|
|
@ -13,6 +13,7 @@ from pydantic import BaseModel, model_validator
|
|||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.embedding_config import EmbeddingConfig
|
||||
from metagpt.configs.file_parser_config import OmniParseConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
|
|
@ -51,6 +52,9 @@ class Config(CLIParams, YamlModel):
|
|||
# RAG Embedding
|
||||
embedding: EmbeddingConfig = EmbeddingConfig()
|
||||
|
||||
# omniparse
|
||||
omniparse: OmniParseConfig = OmniParseConfig()
|
||||
|
||||
# Global Proxy. Will be used if llm.proxy is not set
|
||||
proxy: str = ""
|
||||
|
||||
|
|
|
|||
6
metagpt/configs/file_parser_config.py
Normal file
6
metagpt/configs/file_parser_config.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class OmniParseConfig(YamlModel):
|
||||
api_key: str = ""
|
||||
base_url: str = ""
|
||||
|
|
@ -14,6 +14,7 @@ from llama_index.core.llms import LLM
|
|||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.response_synthesizers import (
|
||||
BaseSynthesizer,
|
||||
get_response_synthesizer,
|
||||
|
|
@ -28,6 +29,7 @@ from llama_index.core.schema import (
|
|||
TransformComponent,
|
||||
)
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.rag.factories import (
|
||||
get_index,
|
||||
get_rag_embedding,
|
||||
|
|
@ -36,6 +38,7 @@ from metagpt.rag.factories import (
|
|||
get_retriever,
|
||||
)
|
||||
from metagpt.rag.interface import NoEmbedding, RAGObject
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
|
|
@ -44,6 +47,9 @@ from metagpt.rag.schema import (
|
|||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
ObjectNode,
|
||||
OmniParseOptions,
|
||||
OmniParseType,
|
||||
ParseResultType,
|
||||
)
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
|
@ -100,7 +106,10 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
if not input_dir and not input_files:
|
||||
raise ValueError("Must provide either `input_dir` or `input_files`.")
|
||||
|
||||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
|
||||
file_extractor = cls._get_file_extractor()
|
||||
documents = SimpleDirectoryReader(
|
||||
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
|
||||
).load_data()
|
||||
cls._fix_document_metadata(documents)
|
||||
|
||||
transformations = transformations or cls._default_transformations()
|
||||
|
|
@ -301,3 +310,23 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
@staticmethod
|
||||
def _default_transformations():
|
||||
return [SentenceSplitter()]
|
||||
|
||||
@staticmethod
|
||||
def _get_file_extractor() -> dict[str:BaseReader]:
|
||||
"""
|
||||
Get the file extractor.
|
||||
Currently, only PDF use OmniParse. Other document types use the built-in reader from llama_index.
|
||||
|
||||
Returns:
|
||||
dict[file_type: BaseReader]
|
||||
"""
|
||||
file_extractor: dict[str:BaseReader] = {}
|
||||
if config.omniparse.base_url:
|
||||
pdf_parser = OmniParse(
|
||||
api_key=config.omniparse.api_key,
|
||||
base_url=config.omniparse.base_url,
|
||||
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
|
||||
)
|
||||
file_extractor[".pdf"] = pdf_parser
|
||||
|
||||
return file_extractor
|
||||
|
|
|
|||
3
metagpt/rag/parsers/__init__.py
Normal file
3
metagpt/rag/parsers/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from metagpt.rag.parsers.omniparse import OmniParse
|
||||
|
||||
__all__ = ["OmniParse"]
|
||||
139
metagpt/rag/parsers/omniparse.py
Normal file
139
metagpt/rag/parsers/omniparse.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import asyncio
|
||||
from fileinput import FileInput
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.async_utils import run_jobs
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.omniparse_client import OmniParseClient
|
||||
|
||||
|
||||
class OmniParse(BaseReader):
|
||||
"""OmniParse"""
|
||||
|
||||
def __init__(
|
||||
self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
api_key: Default None, can be used for authentication later.
|
||||
base_url: OmniParse Base URL for the API.
|
||||
parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values.
|
||||
"""
|
||||
self.parse_options = parse_options or OmniParseOptions()
|
||||
self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout)
|
||||
|
||||
@property
|
||||
def parse_type(self):
|
||||
return self.parse_options.parse_type
|
||||
|
||||
@property
|
||||
def result_type(self):
|
||||
return self.parse_options.result_type
|
||||
|
||||
@parse_type.setter
|
||||
def parse_type(self, parse_type: Union[str, OmniParseType]):
|
||||
if isinstance(parse_type, str):
|
||||
parse_type = OmniParseType(parse_type)
|
||||
self.parse_options.parse_type = parse_type
|
||||
|
||||
@result_type.setter
|
||||
def result_type(self, result_type: Union[str, ParseResultType]):
|
||||
if isinstance(result_type, str):
|
||||
result_type = ParseResultType(result_type)
|
||||
self.parse_options.result_type = result_type
|
||||
|
||||
async def _aload_data(
|
||||
self,
|
||||
file_path: Union[str, bytes, Path],
|
||||
extra_info: Optional[dict] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Load data from the input file_path.
|
||||
|
||||
Args:
|
||||
file_path: File path or file byte data.
|
||||
extra_info: Optional dictionary containing additional information.
|
||||
|
||||
Returns:
|
||||
List[Document]
|
||||
"""
|
||||
try:
|
||||
if self.parse_type == OmniParseType.PDF:
|
||||
# pdf parse
|
||||
parsed_result = await self.omniparse_client.parse_pdf(file_path)
|
||||
else:
|
||||
# other parse use omniparse_client.parse_document
|
||||
# For compatible byte data, additional filename is required
|
||||
extra_info = extra_info or {}
|
||||
filename = extra_info.get("filename")
|
||||
parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename)
|
||||
|
||||
# Get the specified structured data based on result_type
|
||||
content = getattr(parsed_result, self.result_type)
|
||||
docs = [
|
||||
Document(
|
||||
text=content,
|
||||
metadata=extra_info or {},
|
||||
)
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"OMNI Parse Error: {e}")
|
||||
docs = []
|
||||
|
||||
return docs
|
||||
|
||||
async def aload_data(
|
||||
self,
|
||||
file_path: Union[List[FileInput], FileInput],
|
||||
extra_info: Optional[dict] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Load data from the input file_path.
|
||||
|
||||
Args:
|
||||
file_path: File path or file byte data.
|
||||
extra_info: Optional dictionary containing additional information.
|
||||
|
||||
Notes:
|
||||
This method ultimately calls _aload_data for processing.
|
||||
|
||||
Returns:
|
||||
List[Document]
|
||||
"""
|
||||
docs = []
|
||||
if isinstance(file_path, (str, bytes, Path)):
|
||||
# Processing single file
|
||||
docs = await self._aload_data(file_path, extra_info)
|
||||
elif isinstance(file_path, list):
|
||||
# Concurrently process multiple files
|
||||
parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path]
|
||||
doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers)
|
||||
docs = [doc for docs in doc_ret_list for doc in docs]
|
||||
return docs
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
file_path: Union[List[FileInput], FileInput],
|
||||
extra_info: Optional[dict] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Load data from the input file_path.
|
||||
|
||||
Args:
|
||||
file_path: File path or file byte data.
|
||||
extra_info: Optional dictionary containing additional information.
|
||||
|
||||
Notes:
|
||||
This method ultimately calls aload_data for processing.
|
||||
|
||||
Returns:
|
||||
List[Document]
|
||||
"""
|
||||
NestAsyncio.apply_once() # Ensure compatibility with nested async calls
|
||||
return asyncio.run(self.aload_data(file_path, extra_info))
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Literal, Optional, Union
|
||||
from typing import Any, ClassVar, List, Literal, Optional, Union
|
||||
|
||||
from chromadb.api.types import CollectionMetadata
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
|
|
@ -214,3 +214,51 @@ class ObjectNode(TextNode):
|
|||
)
|
||||
|
||||
return metadata.model_dump()
|
||||
|
||||
|
||||
class OmniParseType(str, Enum):
|
||||
"""OmniParseType"""
|
||||
|
||||
PDF = "PDF"
|
||||
DOCUMENT = "DOCUMENT"
|
||||
|
||||
|
||||
class ParseResultType(str, Enum):
|
||||
"""The result type for the parser."""
|
||||
|
||||
TXT = "text"
|
||||
MD = "markdown"
|
||||
JSON = "json"
|
||||
|
||||
|
||||
class OmniParseOptions(BaseModel):
|
||||
"""OmniParse Options config"""
|
||||
|
||||
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
|
||||
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")
|
||||
max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests")
|
||||
num_workers: int = Field(
|
||||
default=5,
|
||||
gt=0,
|
||||
lt=10,
|
||||
description="Number of concurrent requests for multiple files",
|
||||
)
|
||||
|
||||
|
||||
class OminParseImage(BaseModel):
|
||||
image: str = Field(default="", description="image str bytes")
|
||||
image_name: str = Field(default="", description="image name")
|
||||
image_info: Optional[dict] = Field(default={}, description="image info")
|
||||
|
||||
|
||||
class OmniParsedResult(BaseModel):
|
||||
markdown: str = Field(default="", description="markdown text")
|
||||
text: str = Field(default="", description="plain text")
|
||||
images: Optional[List[OminParseImage]] = Field(default=[], description="images")
|
||||
metadata: Optional[dict] = Field(default={}, description="metadata")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def set_markdown(cls, values):
|
||||
if not values.get("markdown"):
|
||||
values["markdown"] = values.get("text")
|
||||
return values
|
||||
|
|
|
|||
239
metagpt/utils/omniparse_client.py
Normal file
239
metagpt/utils/omniparse_client.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import httpx
|
||||
|
||||
from metagpt.rag.schema import OmniParsedResult
|
||||
from metagpt.utils.common import aread_bin
|
||||
|
||||
|
||||
class OmniParseClient:
|
||||
"""
|
||||
OmniParse Server Client
|
||||
This client interacts with the OmniParse server to parse different types of media, documents.
|
||||
|
||||
OmniParse API Documentation: https://docs.cognitivelab.in/api
|
||||
|
||||
Attributes:
|
||||
ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
|
||||
ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
|
||||
ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
|
||||
"""
|
||||
|
||||
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
|
||||
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
|
||||
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
|
||||
|
||||
def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
|
||||
"""
|
||||
Args:
|
||||
api_key: Default None, can be used for authentication later.
|
||||
base_url: Base URL for the API.
|
||||
max_timeout: Maximum request timeout in seconds.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.max_timeout = max_timeout
|
||||
|
||||
self.parse_media_endpoint = "/parse_media"
|
||||
self.parse_website_endpoint = "/parse_website"
|
||||
self.parse_document_endpoint = "/parse_document"
|
||||
|
||||
async def _request_parse(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str = "POST",
|
||||
files: dict = None,
|
||||
params: dict = None,
|
||||
data: dict = None,
|
||||
json: dict = None,
|
||||
headers: dict = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Request OmniParse API to parse a document.
|
||||
|
||||
Args:
|
||||
endpoint (str): API endpoint.
|
||||
method (str, optional): HTTP method to use. Default is "POST".
|
||||
files (dict, optional): Files to include in the request.
|
||||
params (dict, optional): Query string parameters.
|
||||
data (dict, optional): Form data to include in the request body.
|
||||
json (dict, optional): JSON data to include in the request body.
|
||||
headers (dict, optional): HTTP headers to include in the request.
|
||||
**kwargs: Additional keyword arguments for httpx.AsyncClient.request()
|
||||
|
||||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
method = method.upper()
|
||||
headers = headers or {}
|
||||
_headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
|
||||
headers.update(**_headers)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.request(
|
||||
url=url,
|
||||
method=method,
|
||||
files=files,
|
||||
params=params,
|
||||
json=json,
|
||||
data=data,
|
||||
headers=headers,
|
||||
timeout=self.max_timeout,
|
||||
**kwargs,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
|
||||
"""
|
||||
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
OmniParsedResult: The result of the document parsing.
|
||||
"""
|
||||
self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(file_input, bytes_filename)
|
||||
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult:
|
||||
"""
|
||||
Parse pdf document.
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
OmniParsedResult: The result of the pdf parsing.
|
||||
"""
|
||||
self.verify_file_ext(file_input, {".pdf"})
|
||||
# parse_pdf supports parsing by accepting only the byte data of the file.
|
||||
file_info = await self.get_file_info(file_input, only_bytes=True)
|
||||
endpoint = f"{self.parse_document_endpoint}/pdf"
|
||||
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""
|
||||
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(file_input, bytes_filename)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
|
||||
|
||||
async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""
|
||||
Parse audio-type data (supports ".mp3", ".wav", ".aac").
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(file_input, bytes_filename)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
|
||||
|
||||
@staticmethod
|
||||
def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
|
||||
"""
|
||||
Verify the file extension.
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
allowed_file_extensions: Set of allowed file extensions.
|
||||
bytes_filename: Filename to use for verification when `file_input` is byte data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
"""
|
||||
verify_file_path = None
|
||||
if isinstance(file_input, (str, Path)):
|
||||
verify_file_path = str(file_input)
|
||||
elif isinstance(file_input, bytes) and bytes_filename:
|
||||
verify_file_path = bytes_filename
|
||||
|
||||
if not verify_file_path:
|
||||
# Do not verify if only byte data is provided
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(verify_file_path)[1].lower()
|
||||
if file_ext not in allowed_file_extensions:
|
||||
raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")
|
||||
|
||||
@staticmethod
|
||||
async def get_file_info(
|
||||
file_input: Union[str, bytes, Path],
|
||||
bytes_filename: str = None,
|
||||
only_bytes: bool = False,
|
||||
) -> Union[bytes, tuple]:
|
||||
"""
|
||||
Get file information.
|
||||
|
||||
Args:
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
|
||||
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
|
||||
|
||||
Raises:
|
||||
ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type.
|
||||
|
||||
Notes:
|
||||
Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types,
|
||||
the MIME type of the file must be specified when uploading.
|
||||
|
||||
Returns: [bytes, tuple]
|
||||
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
|
||||
"""
|
||||
if isinstance(file_input, (str, Path)):
|
||||
filename = os.path.basename(str(file_input))
|
||||
file_bytes = await aread_bin(file_input)
|
||||
|
||||
if only_bytes:
|
||||
return file_bytes
|
||||
|
||||
mime_type = mimetypes.guess_type(file_input)[0]
|
||||
return filename, file_bytes, mime_type
|
||||
elif isinstance(file_input, bytes):
|
||||
if only_bytes:
|
||||
return file_input
|
||||
if not bytes_filename:
|
||||
raise ValueError("bytes_filename must be set when passing bytes")
|
||||
|
||||
mime_type = mimetypes.guess_type(bytes_filename)[0]
|
||||
return bytes_filename, file_input, mime_type
|
||||
else:
|
||||
raise ValueError("file_input must be a string (file path) or bytes.")
|
||||
|
|
@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM
|
|||
from llama_index.core.schema import Document, NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
|
||||
|
|
@ -37,6 +38,10 @@ class TestSimpleEngine:
|
|||
def mock_get_response_synthesizer(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_file_extractor(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor")
|
||||
|
||||
def test_from_docs(
|
||||
self,
|
||||
mocker,
|
||||
|
|
@ -44,6 +49,7 @@ class TestSimpleEngine:
|
|||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
mock_get_file_extractor,
|
||||
):
|
||||
# Mock
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
|
|
@ -53,6 +59,8 @@ class TestSimpleEngine:
|
|||
mock_get_retriever.return_value = mocker.MagicMock()
|
||||
mock_get_rankers.return_value = [mocker.MagicMock()]
|
||||
mock_get_response_synthesizer.return_value = mocker.MagicMock()
|
||||
file_extractor = mocker.MagicMock()
|
||||
mock_get_file_extractor.return_value = file_extractor
|
||||
|
||||
# Setup
|
||||
input_dir = "test_dir"
|
||||
|
|
@ -75,7 +83,9 @@ class TestSimpleEngine:
|
|||
)
|
||||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_simple_directory_reader.assert_called_once_with(
|
||||
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
|
||||
)
|
||||
mock_get_retriever.assert_called_once()
|
||||
mock_get_rankers.assert_called_once()
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
|
|
@ -298,3 +308,17 @@ class TestSimpleEngine:
|
|||
# Assert
|
||||
assert "obj" in node.node.metadata
|
||||
assert node.node.metadata["obj"] == expected_obj
|
||||
|
||||
def test_get_file_extractor(self, mocker):
|
||||
# mock no omniparse config
|
||||
mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)
|
||||
mock_omniparse_config.base_url = ""
|
||||
|
||||
file_extractor = SimpleEngine._get_file_extractor()
|
||||
assert file_extractor == {}
|
||||
|
||||
# mock have omniparse config
|
||||
mock_omniparse_config.base_url = "http://localhost:8000"
|
||||
file_extractor = SimpleEngine._get_file_extractor()
|
||||
assert ".pdf" in file_extractor
|
||||
assert isinstance(file_extractor[".pdf"], OmniParse)
|
||||
|
|
|
|||
118
tests/metagpt/rag/parser/test_omniparse.py
Normal file
118
tests/metagpt/rag/parser/test_omniparse.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
import pytest
|
||||
from llama_index.core import Document
|
||||
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.schema import (
|
||||
OmniParsedResult,
|
||||
OmniParseOptions,
|
||||
OmniParseType,
|
||||
ParseResultType,
|
||||
)
|
||||
from metagpt.utils.omniparse_client import OmniParseClient
|
||||
|
||||
# test data
|
||||
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
|
||||
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
|
||||
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
|
||||
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
|
||||
|
||||
|
||||
class TestOmniParseClient:
|
||||
parse_client = OmniParseClient()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_pdf(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest content"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
|
||||
assert parse_ret == mock_parsed_ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_document(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_document"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
|
||||
with open(TEST_DOCX, "rb") as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# bytes data must provide bytes_filename
|
||||
await self.parse_client.parse_document(file_bytes)
|
||||
|
||||
parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
|
||||
assert parse_ret == mock_parsed_ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_video(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_video"
|
||||
mock_request_parse.return_value = {
|
||||
"text": mock_content,
|
||||
"metadata": {},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
# Wrong file extension test
|
||||
await self.parse_client.parse_video(TEST_DOCX)
|
||||
|
||||
parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
|
||||
assert "text" in parse_ret and "metadata" in parse_ret
|
||||
assert parse_ret["text"] == mock_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_audio(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_audio"
|
||||
mock_request_parse.return_value = {
|
||||
"text": mock_content,
|
||||
"metadata": {},
|
||||
}
|
||||
parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
|
||||
assert "text" in parse_ret and "metadata" in parse_ret
|
||||
assert parse_ret["text"] == mock_content
|
||||
|
||||
|
||||
class TestOmniParse:
|
||||
@pytest.fixture
|
||||
def mock_omniparse(self):
|
||||
parser = OmniParse(
|
||||
parse_options=OmniParseOptions(
|
||||
parse_type=OmniParseType.PDF,
|
||||
result_type=ParseResultType.MD,
|
||||
max_timeout=120,
|
||||
num_workers=3,
|
||||
)
|
||||
)
|
||||
return parser
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_data(self, mock_omniparse, mock_request_parse):
|
||||
# mock
|
||||
mock_content = "#test title\ntest content"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
|
||||
# single file
|
||||
documents = mock_omniparse.load_data(file_path=TEST_PDF)
|
||||
doc = documents[0]
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
|
||||
|
||||
# multi files
|
||||
file_paths = [TEST_DOCX, TEST_PDF]
|
||||
mock_omniparse.parse_type = OmniParseType.DOCUMENT
|
||||
documents = await mock_omniparse.aload_data(file_path=file_paths)
|
||||
doc = documents[0]
|
||||
|
||||
# assert
|
||||
assert isinstance(doc, Document)
|
||||
assert len(documents) == len(file_paths)
|
||||
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
|
||||
Loading…
Add table
Add a link
Reference in a new issue