diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 64cce630f..23ef79555 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -59,3 +59,7 @@ iflytek_api_key: "YOUR_API_KEY" iflytek_api_secret: "YOUR_API_SECRET" metagpt_tti_url: "YOUR_MODEL_URL" + +omniparse: + api_key: "YOUR_API_KEY" + base_url: "YOUR_BASE_URL" diff --git a/config/config2.yaml b/config/config2.yaml index 7ae8525f5..b3f24539c 100644 --- a/config/config2.yaml +++ b/config/config2.yaml @@ -5,8 +5,4 @@ llm: api_type: "openai" # or azure / ollama / groq etc. model: "gpt-4-turbo" # or gpt-3.5-turbo base_url: "https://api.openai.com/v1" # or forward url / other llm url - api_key: "xxxx" - -omniparse: - api_key: "your_api_key" - base_url: "http://192.168.50.126:8000" \ No newline at end of file + api_key: "YOUR_API_KEY" \ No newline at end of file diff --git a/examples/data/parse/test01.docx b/examples/data/omniparse/test01.docx similarity index 100% rename from examples/data/parse/test01.docx rename to examples/data/omniparse/test01.docx diff --git a/examples/data/parse/test02.pdf b/examples/data/omniparse/test02.pdf similarity index 100% rename from examples/data/parse/test02.pdf rename to examples/data/omniparse/test02.pdf diff --git a/examples/data/parse/test03.mp4 b/examples/data/omniparse/test03.mp4 similarity index 100% rename from examples/data/parse/test03.mp4 rename to examples/data/omniparse/test03.mp4 diff --git a/examples/data/parse/test04.mp3 b/examples/data/omniparse/test04.mp3 similarity index 100% rename from examples/data/parse/test04.mp3 rename to examples/data/omniparse/test04.mp3 diff --git a/examples/rag/omniparse_client.py b/examples/rag/omniparse.py similarity index 68% rename from examples/rag/omniparse_client.py rename to examples/rag/omniparse.py index a528e535f..7cea8f954 100644 --- a/examples/rag/omniparse_client.py +++ b/examples/rag/omniparse.py @@ -1,18 +1,16 @@ import asyncio -from llama_parse import ResultType - from metagpt.config2 import config -from metagpt.logs import logger -from metagpt.rag.parser.omniparse.client import OmniParseClient -from metagpt.rag.parser.omniparse.parse import OmniParse -from metagpt.rag.schema import OmniParseOptions, OmniParseType from metagpt.const import EXAMPLE_DATA_PATH +from metagpt.logs import logger +from metagpt.rag.parser import OmniParse +from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType +from metagpt.utils.omniparse_client import OmniParseClient -TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx" -TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf" -TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4" -TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3" +TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx" +TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf" +TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4" +TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3" TEST_WEBSITE_URL = "https://github.com/geekan/MetaGPT" @@ -37,10 +35,6 @@ async def omniparse_client_example(): audio_parse_ret = await client.parse_audio(filelike=TEST_AUDIO) logger.info(audio_parse_ret) - # website fixme:omniparse官方api还存在问题 - # website_parse_ret = await client.parse_website(url=TEST_WEBSITE_URL) - # logger.info(website_parse_ret) - async def omniparse_example(): parser = OmniParse( @@ -48,10 +42,10 @@ async def omniparse_example(): base_url=config.omniparse.base_url, parse_options=OmniParseOptions( parse_type=OmniParseType.PDF, - result_type=ResultType.MD, + result_type=ParseResultType.MD, max_timeout=120, num_workers=3, - ) + ), ) ret = parser.load_data(file_path=TEST_PDF) logger.info(ret) @@ -67,5 +61,5 @@ async def main(): await omniparse_example() -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/metagpt/config2.py b/metagpt/config2.py index 3aaad28e4..96b677b65 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -12,7 +12,8 @@ from typing import Dict, Iterable, List, Literal, Optional from pydantic import BaseModel, model_validator from metagpt.configs.browser_config import BrowserConfig -from metagpt.configs.embedding_config import EmbeddingConfig, OmniParseConfig +from metagpt.configs.embedding_config import EmbeddingConfig +from metagpt.configs.file_parser_config import OmniParseConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.redis_config import RedisConfig diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py index bc7411274..f9b41b9dc 100644 --- a/metagpt/configs/embedding_config.py +++ b/metagpt/configs/embedding_config.py @@ -52,8 +52,3 @@ class EmbeddingConfig(YamlModel): if v == "": return None return v - - -class OmniParseConfig(YamlModel): - api_key: str = "" - base_url: str = "" diff --git a/metagpt/configs/file_parser_config.py b/metagpt/configs/file_parser_config.py new file mode 100644 index 000000000..39742c8a4 --- /dev/null +++ b/metagpt/configs/file_parser_config.py @@ -0,0 +1,6 @@ +from metagpt.utils.yaml_model import YamlModel + + +class OmniParseConfig(YamlModel): + api_key: str = "" + base_url: str = "" diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 3bb665d10..f3f86111d 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -14,6 +14,7 @@ from llama_index.core.llms import LLM from llama_index.core.node_parser import SentenceSplitter from llama_index.core.postprocessor.types import BaseNodePostprocessor from llama_index.core.query_engine import RetrieverQueryEngine +from llama_index.core.readers.base import BaseReader from llama_index.core.response_synthesizers import ( BaseSynthesizer, get_response_synthesizer, @@ -27,7 +28,6 @@ from llama_index.core.schema import ( QueryType, TransformComponent, ) -from llama_parse import ResultType from metagpt.config2 import config from metagpt.rag.factories import ( @@ -38,7 +38,7 @@ from metagpt.rag.factories import ( get_retriever, ) from metagpt.rag.interface import NoEmbedding, RAGObject -from metagpt.rag.parser.omniparse.parse import OmniParse +from metagpt.rag.parser import OmniParse from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( @@ -46,7 +46,10 @@ from metagpt.rag.schema import ( BaseRankerConfig, BaseRetrieverConfig, BM25RetrieverConfig, - ObjectNode, OmniParseOptions, OmniParseType, + ObjectNode, + OmniParseOptions, + OmniParseType, + ParseResultType, ) from metagpt.utils.common import import_class @@ -76,18 +79,6 @@ class SimpleEngine(RetrieverQueryEngine): ) self._transformations = transformations or self._default_transformations() - @classmethod - def get_file_extractor(cls, file_type: str): - if not config.omniparse.base_url: - return - parser = OmniParse( - api_key=config.omniparse.api_key, - base_url=config.omniparse.base_url, - parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ResultType.MD) - ) - file_extractor = {file_type: parser} - return file_extractor - @classmethod def from_docs( cls, @@ -115,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine): if not input_dir and not input_files: raise ValueError("Must provide either `input_dir` or `input_files`.") - file_extractor = cls.get_file_extractor(file_type=".pdf") + file_extractor = cls._get_file_extractor(file_type=".pdf") documents = SimpleDirectoryReader( input_dir=input_dir, input_files=input_files, file_extractor=file_extractor ).load_data() @@ -319,3 +310,31 @@ class SimpleEngine(RetrieverQueryEngine): @staticmethod def _default_transformations(): return [SentenceSplitter()] + + @staticmethod + def _get_file_extractor(file_type: str = None) -> dict[str:BaseReader]: + """ + Get the file extractor for a specified file type. + If no file type is provided, return all available extractors. + Currently, only OmniParse PDF extraction is supported. + + Args: + file_type: The type of file for which the extractor is needed. Defaults to None. + + Returns: + dict[file_type: BaseReader] + """ + file_extractor_mapping: dict[str:BaseReader] = {} + if config.omniparse.base_url: + pdf_parser = OmniParse( + api_key=config.omniparse.api_key, + base_url=config.omniparse.base_url, + parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD), + ) + file_extractor_mapping[".pdf"] = pdf_parser + + if file_type: + file_extractor = file_extractor_mapping.get(file_type) + return {file_type: file_extractor} if file_extractor else {} + + return file_extractor_mapping diff --git a/metagpt/rag/parser/__init__.py b/metagpt/rag/parser/__init__.py index e69de29bb..90d7eb971 100644 --- a/metagpt/rag/parser/__init__.py +++ b/metagpt/rag/parser/__init__.py @@ -0,0 +1,3 @@ +from metagpt.rag.parser.omniparse import OmniParse + +__all__ = ["OmniParse"] diff --git a/metagpt/rag/parser/omniparse/parse.py b/metagpt/rag/parser/omniparse.py similarity index 54% rename from metagpt/rag/parser/omniparse/parse.py rename to metagpt/rag/parser/omniparse.py index 176620934..85227dc06 100644 --- a/metagpt/rag/parser/omniparse/parse.py +++ b/metagpt/rag/parser/omniparse.py @@ -1,28 +1,31 @@ import asyncio from fileinput import FileInput from pathlib import Path -from typing import List, Union, Optional +from typing import List, Optional, Union from llama_index.core import Document from llama_index.core.async_utils import run_jobs from llama_index.core.readers.base import BaseReader from llama_parse import ResultType -from metagpt.rag.parser.omniparse.client import OmniParseClient -from metagpt.rag.schema import OmniParseOptions, OmniParseType from metagpt.logs import logger +from metagpt.rag.schema import OmniParseOptions, OmniParseType from metagpt.utils.async_helper import NestAsyncio +from metagpt.utils.omniparse_client import OmniParseClient class OmniParse(BaseReader): """OmniParse""" def __init__( - self, - api_key=None, - base_url="http://localhost:8000", - parse_options: OmniParseOptions = None + self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None ): + """ + Args: + api_key: Default None, can be used for authentication later. + base_url: OmniParse Base URL for the API. + parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values. + """ self.parse_options = parse_options or OmniParseOptions() self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout) @@ -47,20 +50,32 @@ class OmniParse(BaseReader): self.parse_options.result_type = result_type async def _aload_data( - self, - file_path: Union[str, bytes, Path], - extra_info: Optional[dict] = None, + self, + file_path: Union[str, bytes, Path], + extra_info: Optional[dict] = None, ) -> List[Document]: + """ + Load data from the input file_path. + + Args: + file_path: File path or file byte data. + extra_info: Optional dictionary containing additional information. + + Returns: + List[Document] + """ try: if self.parse_type == OmniParseType.PDF: - # 目前先只支持pdf解析 + # pdf parse parsed_result = await self.omniparse_client.parse_pdf(file_path) else: + # other parse use omniparse_client.parse_document + # For compatible byte data, additional filename is required extra_info = extra_info or {} - filename = extra_info.get("filename") # 兼容字节数据要额外传filename + filename = extra_info.get("filename") parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename) - # 获取指定的结构数据 + # Get the specified structured data based on result_type content = getattr(parsed_result, self.result_type) docs = [ Document( @@ -75,26 +90,51 @@ class OmniParse(BaseReader): return docs async def aload_data( - self, - file_path: Union[List[FileInput], FileInput], - extra_info: Optional[dict] = None, + self, + file_path: Union[List[FileInput], FileInput], + extra_info: Optional[dict] = None, ) -> List[Document]: + """ + Load data from the input file_path. + + Args: + file_path: File path or file byte data. + extra_info: Optional dictionary containing additional information. + + Notes: + This method ultimately calls _aload_data for processing. + + Returns: + List[Document] + """ docs = [] if isinstance(file_path, (str, bytes, Path)): - # 处理单个 + # Processing single file docs = await self._aload_data(file_path, extra_info) elif isinstance(file_path, list): - # 并发处理多个 + # Concurrently process multiple files parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path] doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers) docs = [doc for docs in doc_ret_list for doc in docs] return docs def load_data( - self, - file_path: Union[List[FileInput], FileInput], - extra_info: Optional[dict] = None, + self, + file_path: Union[List[FileInput], FileInput], + extra_info: Optional[dict] = None, ) -> List[Document]: - """Load data from the input path.""" - NestAsyncio.apply_once() # 兼容异步嵌套调用 + """ + Load data from the input file_path. + + Args: + file_path: File path or file byte data. + extra_info: Optional dictionary containing additional information. + + Notes: + This method ultimately calls aload_data for processing. + + Returns: + List[Document] + """ + NestAsyncio.apply_once() # Ensure compatibility with nested async calls return asyncio.run(self.aload_data(file_path, extra_info)) diff --git a/metagpt/rag/parser/omniparse/__init__.py b/metagpt/rag/parser/omniparse/__init__.py deleted file mode 100644 index d453d14d6..000000000 --- a/metagpt/rag/parser/omniparse/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .client import OmniParseClient -from .parse import OmniParse diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 7f34a0be9..b02ef16cc 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -8,7 +8,6 @@ from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode -from llama_parse import ResultType from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from metagpt.config2 import config @@ -218,23 +217,31 @@ class ObjectNode(TextNode): class OmniParseType(str, Enum): - """OmniParse解析类型""" + """OmniParseType""" PDF = "PDF" DOCUMENT = "DOCUMENT" +class ParseResultType(str, Enum): + """The result type for the parser.""" + + TXT = "text" + MD = "markdown" + JSON = "json" + + class OmniParseOptions(BaseModel): """OmniParse可选配置""" - result_type: ResultType = Field(default=ResultType.MD, description="OmniParse解析返回的结果类型") - parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型,默认文档类型") - max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时") + result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type") + parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type") + max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests") num_workers: int = Field( default=5, gt=0, lt=10, - description="多文件列表时并发请求数量", + description="Number of concurrent requests for multiple files", ) diff --git a/metagpt/rag/parser/omniparse/client.py b/metagpt/utils/omniparse_client.py similarity index 55% rename from metagpt/rag/parser/omniparse/client.py rename to metagpt/utils/omniparse_client.py index 7386bff0d..787d6fe88 100644 --- a/metagpt/rag/parser/omniparse/client.py +++ b/metagpt/utils/omniparse_client.py @@ -3,28 +3,35 @@ import os from pathlib import Path from typing import Union -import aiofiles import httpx from metagpt.rag.schema import OmniParsedResult +from metagpt.utils.common import aread_bin class OmniParseClient: """ OmniParse Server Client - OmniParse API Docs: https://docs.cognitivelab.in/api + This client interacts with the OmniParse server to parse different types of media, documents, and websites. + + OmniParse API Documentation: https://docs.cognitivelab.in/api + + Attributes: + ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions. + ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions. + ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions. """ ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"} ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"} ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"} - def __init__(self, api_key=None, base_url="http://localhost:8000", max_timeout=120): + def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120): """ Args: - api_key: 默认 None 后续可用于鉴权 - base_url: api 基础url - max_timeout: 请求最大超时时间单位s + api_key: Default None, can be used for authentication later. + base_url: Base URL for the API. + max_timeout: Maximum request timeout in seconds. """ self.api_key = api_key self.base_url = base_url @@ -46,19 +53,20 @@ class OmniParseClient: **kwargs, ) -> dict: """ - 请求api解析文档 + Request OmniParse API to parse a document. Args: endpoint (str): API endpoint. - files (dict, optional): 请求文件数据. - params (dict, optional): 查询字符串参数. - data (dict, optional): 请求体数据. - json (dict, optional): 请求json数据. - headers (dict, optional): 请求头数据. - **kwargs: 其他 httpx.AsyncClient.request() 关键字参数 + method (str, optional): HTTP method to use. Default is "POST". + files (dict, optional): Files to include in the request. + params (dict, optional): Query string parameters. + data (dict, optional): Form data to include in the request body. + json (dict, optional): JSON data to include in the request body. + headers (dict, optional): HTTP headers to include in the request. + **kwargs: Additional keyword arguments for httpx.AsyncClient.request() Returns: - dict: 响应的json数据 + dict: JSON response data. """ url = f"{self.base_url}{endpoint}" method = method.upper() @@ -80,17 +88,94 @@ class OmniParseClient: response.raise_for_status() return response.json() + async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult: + """ + Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx"). + + Args: + filelike: File path or file byte data. + bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + OmniParsedResult: The result of the document parsing. + """ + self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(filelike, bytes_filename) + resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info}) + data = OmniParsedResult(**resp) + return data + + async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult: + """ + Parse pdf document. + + Args: + filelike: File path or file byte data. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + OmniParsedResult: The result of the pdf parsing. + """ + self.verify_file_ext(filelike, {".pdf"}) + file_info = await self.get_file_info(filelike, only_bytes=True) + endpoint = f"{self.parse_document_endpoint}/pdf" + resp = await self._request_parse(endpoint=endpoint, files={"file": file_info}) + data = OmniParsedResult(**resp) + return data + + async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict: + """ + Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov"). + + Args: + filelike: File path or file byte data. + bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + dict: JSON response data. + """ + self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(filelike, bytes_filename) + return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info}) + + async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict: + """ + Parse audio-type data (supports ".mp3", ".wav", ".aac"). + + Args: + filelike: File path or file byte data. + bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + dict: JSON response data. + """ + self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(filelike, bytes_filename) + return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info}) + @staticmethod def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None): """ - 校验文件后缀 + Verify the file extension. + Args: - filelike: 文件路径 or 文件字节数据 - allowed_file_extensions: 允许的文件扩展名 - bytes_filename: 当字节数据时使用这个参数校验 + filelike: File path or file byte data. + allowed_file_extensions: Set of allowed file extensions. + bytes_filename: Filename to use for verification when `filelike` is byte data. Raises: - ValueError + ValueError: If the file extension is not allowed. Returns: """ @@ -101,7 +186,7 @@ class OmniParseClient: verify_file_path = bytes_filename if not verify_file_path: - # 仅传字节数据时不校验 + # Do not verify if only byte data is provided return file_ext = os.path.splitext(verify_file_path)[1].lower() @@ -112,28 +197,29 @@ class OmniParseClient: async def get_file_info( filelike: Union[str, bytes, Path], bytes_filename: str = None, - only_bytes=True, + only_bytes: bool = False, ) -> Union[bytes, tuple]: """ - 获取文件字节信息 + Get file information. + Args: - filelike: 文件数据 - bytes_filename: 通过字节数据上传需要指定文件名称,方便获取mime_type - only_bytes: 是否只需要字节数据 + filelike: File path or file byte data. + bytes_filename: Filename to use when uploading byte data, useful for determining MIME type. + only_bytes: Whether to return only byte data. Default is False, which returns a tuple. Raises: - ValueError + ValueError: If bytes_filename is not provided when filelike is bytes or if filelike is not a valid type. Notes: - 由于 parse_document 支持多种文件解析,需要上传文件时指定文件的mime_type + Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types, + the MIME type of the file must be specified when uploading. - Returns: - [bytes, tuple] + Returns: [bytes, tuple] + Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type). """ if isinstance(filelike, (str, Path)): filename = os.path.basename(str(filelike)) - async with aiofiles.open(filelike, "rb") as file: - file_bytes = await file.read() + file_bytes = await aread_bin(filelike) if only_bytes: return file_bytes @@ -150,60 +236,3 @@ class OmniParseClient: return bytes_filename, filelike, mime_type else: raise ValueError("filelike must be a string (file path) or bytes.") - - async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult: - """ - 解析文档类型数据(支持 ".pdf", ".ppt", ".pptx", ".doc", ".docs") - Args: - filelike: 文件路径 or 文件字节数据 - bytes_filename: 字节数据名称,方便获取mime_type 用于httpx请求 - - Raises - ValueError - - Returns: - OmniParsedResult - """ - self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename) - file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False) - resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info}) - data = OmniParsedResult(**resp) - return data - - async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult: - """ - 解析PDF文档 - Args: - filelike: 文件路径 or 文件字节数据 - - Raises - ValueError - - Returns: - OmniParsedResult - """ - self.verify_file_ext(filelike, {".pdf"}) - file_info = await self.get_file_info(filelike) - endpoint = f"{self.parse_document_endpoint}/pdf" - resp = await self._request_parse(endpoint=endpoint, files={"file": file_info}) - data = OmniParsedResult(**resp) - return data - - async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict: - """解析视频""" - self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename) - file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False) - return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info}) - - async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict: - """解析音频""" - self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename) - file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False) - return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info}) - - async def parse_website(self, url: str) -> dict: - """ - 解析网站 - fixme:官方api还存在问题 - """ - return await self._request_parse(f"{self.parse_website_endpoint}/parse", params={"url": url}) diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 61b9816a5..6bbe011c0 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM from llama_index.core.schema import Document, NodeWithScore, TextNode from metagpt.rag.engines import SimpleEngine +from metagpt.rag.parser import OmniParse from metagpt.rag.retrievers import SimpleHybridRetriever from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode @@ -39,7 +40,7 @@ class TestSimpleEngine: @pytest.fixture def mock_get_file_extractor(self, mocker): - return mocker.patch("metagpt.rag.engines.simple.SimpleEngine.get_file_extractor") + return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor") def test_from_docs( self, @@ -307,3 +308,17 @@ class TestSimpleEngine: # Assert assert "obj" in node.node.metadata assert node.node.metadata["obj"] == expected_obj + + def test_get_file_extractor(self, mocker): + # mock no omniparse config + mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True) + mock_omniparse_config.base_url = "" + + file_extractor = SimpleEngine._get_file_extractor() + assert file_extractor == {} + + # mock have omniparse config + mock_omniparse_config.base_url = "http://localhost:8000" + file_extractor = SimpleEngine._get_file_extractor(file_type=".pdf") + assert ".pdf" in file_extractor + assert isinstance(file_extractor[".pdf"], OmniParse) diff --git a/tests/metagpt/rag/parser/test_omniparse.py b/tests/metagpt/rag/parser/test_omniparse.py index 79b173d5b..57b5329cf 100644 --- a/tests/metagpt/rag/parser/test_omniparse.py +++ b/tests/metagpt/rag/parser/test_omniparse.py @@ -1,32 +1,118 @@ import pytest +from llama_index.core import Document from metagpt.const import EXAMPLE_DATA_PATH -from metagpt.rag.parser.omniparse import OmniParseClient -from metagpt.rag.schema import OmniParsedResult +from metagpt.rag.parser import OmniParse +from metagpt.rag.schema import ( + OmniParsedResult, + OmniParseOptions, + OmniParseType, + ParseResultType, +) +from metagpt.utils.omniparse_client import OmniParseClient + +# test data +TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx" +TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf" +TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4" +TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3" class TestOmniParseClient: parse_client = OmniParseClient() - # test data - TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx" - TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf" - TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4" - TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3" - @pytest.fixture - def request_parse(self, mocker): + def mock_request_parse(self, mocker): return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse") @pytest.mark.asyncio - async def test_parse_pdf(self, request_parse): + async def test_parse_pdf(self, mock_request_parse): mock_content = "#test title\ntest content" mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content) - request_parse.return_value = mock_parsed_ret.model_dump() - parse_ret = await self.parse_client.parse_pdf(self.TEST_PDF) + mock_request_parse.return_value = mock_parsed_ret.model_dump() + parse_ret = await self.parse_client.parse_pdf(TEST_PDF) assert parse_ret == mock_parsed_ret + @pytest.mark.asyncio + async def test_parse_document(self, mock_request_parse): + mock_content = "#test title\ntest_parse_document" + mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content) + mock_request_parse.return_value = mock_parsed_ret.model_dump() + + with open(TEST_DOCX, "rb") as f: + file_bytes = f.read() + + with pytest.raises(ValueError): + # bytes data must provide bytes_filename + await self.parse_client.parse_document(file_bytes) + + parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx") + assert parse_ret == mock_parsed_ret + + @pytest.mark.asyncio + async def test_parse_video(self, mock_request_parse): + mock_content = "#test title\ntest_parse_video" + mock_request_parse.return_value = { + "text": mock_content, + "metadata": {}, + } + with pytest.raises(ValueError): + # Wrong file extension test + await self.parse_client.parse_video(TEST_DOCX) + + parse_ret = await self.parse_client.parse_video(TEST_VIDEO) + assert "text" in parse_ret and "metadata" in parse_ret + assert parse_ret["text"] == mock_content + + @pytest.mark.asyncio + async def test_parse_audio(self, mock_request_parse): + mock_content = "#test title\ntest_parse_audio" + mock_request_parse.return_value = { + "text": mock_content, + "metadata": {}, + } + parse_ret = await self.parse_client.parse_audio(TEST_AUDIO) + assert "text" in parse_ret and "metadata" in parse_ret + assert parse_ret["text"] == mock_content + class TestOmniParse: - def test_load_data(self): - pass + @pytest.fixture + def mock_omniparse(self): + parser = OmniParse( + parse_options=OmniParseOptions( + parse_type=OmniParseType.PDF, + result_type=ParseResultType.MD, + max_timeout=120, + num_workers=3, + ) + ) + return parser + + @pytest.fixture + def mock_request_parse(self, mocker): + return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse") + + @pytest.mark.asyncio + async def test_load_data(self, mock_omniparse, mock_request_parse): + # mock + mock_content = "#test title\ntest content" + mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content) + mock_request_parse.return_value = mock_parsed_ret.model_dump() + + # single file + documents = mock_omniparse.load_data(file_path=TEST_PDF) + doc = documents[0] + assert isinstance(doc, Document) + assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown + + # multi files + file_paths = [TEST_DOCX, TEST_PDF] + mock_omniparse.parse_type = OmniParseType.DOCUMENT + documents = await mock_omniparse.aload_data(file_path=file_paths) + doc = documents[0] + + # assert + assert isinstance(doc, Document) + assert len(documents) == len(file_paths) + assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown