diff --git a/examples/data/rag/test01.docx b/examples/data/parse/test01.docx similarity index 100% rename from examples/data/rag/test01.docx rename to examples/data/parse/test01.docx diff --git a/examples/data/rag/test02.pdf b/examples/data/parse/test02.pdf similarity index 100% rename from examples/data/rag/test02.pdf rename to examples/data/parse/test02.pdf diff --git a/examples/data/parse/test03.mp4 b/examples/data/parse/test03.mp4 new file mode 100644 index 000000000..2f1a222da Binary files /dev/null and b/examples/data/parse/test03.mp4 differ diff --git a/examples/data/parse/test04.mp3 b/examples/data/parse/test04.mp3 new file mode 100644 index 000000000..2c8e149d8 Binary files /dev/null and b/examples/data/parse/test04.mp3 differ diff --git a/examples/rag/omniparse_client.py b/examples/rag/omniparse_client.py index 61622ee36..a528e535f 100644 --- a/examples/rag/omniparse_client.py +++ b/examples/rag/omniparse_client.py @@ -7,30 +7,56 @@ from metagpt.logs import logger from metagpt.rag.parser.omniparse.client import OmniParseClient from metagpt.rag.parser.omniparse.parse import OmniParse from metagpt.rag.schema import OmniParseOptions, OmniParseType +from metagpt.const import EXAMPLE_DATA_PATH + +TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx" +TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf" +TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4" +TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3" +TEST_WEBSITE_URL = "https://github.com/geekan/MetaGPT" async def omniparse_client_example(): client = OmniParseClient(base_url=config.omniparse.base_url) - with open("../data/rag/test01.docx", "rb") as f: + # docx + with open(TEST_DOCX, "rb") as f: filelike = f.read() - parse_document_ret = await client.parse_document(filelike=filelike, bytes_filename="test_01.docx") - logger.info(parse_document_ret) + document_parse_ret = await client.parse_document(filelike=filelike, bytes_filename="test_01.docx") + logger.info(document_parse_ret) - parse_pdf_ret = await client.parse_pdf(filelike="../data/rag/test02.pdf") - logger.info(parse_pdf_ret) + # pdf + pdf_parse_ret = await client.parse_pdf(filelike=TEST_PDF) + logger.info(pdf_parse_ret) + + # video + video_parse_ret = await client.parse_video(filelike=TEST_VIDEO) + logger.info(video_parse_ret) + + # audio + audio_parse_ret = await client.parse_audio(filelike=TEST_AUDIO) + logger.info(audio_parse_ret) + + # website fixme:omniparse官方api还存在问题 + # website_parse_ret = await client.parse_website(url=TEST_WEBSITE_URL) + # logger.info(website_parse_ret) async def omniparse_example(): parser = OmniParse( api_key=config.omniparse.api_key, base_url=config.omniparse.base_url, - parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ResultType.MD) + parse_options=OmniParseOptions( + parse_type=OmniParseType.PDF, + result_type=ResultType.MD, + max_timeout=120, + num_workers=3, + ) ) - ret = await parser.aload_data(file_path="../data/rag/test02.pdf") + ret = parser.load_data(file_path=TEST_PDF) logger.info(ret) - file_paths = ["../data/rag/test01.docx", "../data/rag/test02.pdf"] + file_paths = [TEST_DOCX, TEST_PDF] parser.parse_type = OmniParseType.DOCUMENT ret = await parser.aload_data(file_path=file_paths) logger.info(ret) diff --git a/metagpt/rag/parser/omniparse/client.py b/metagpt/rag/parser/omniparse/client.py index 5d2c330ef..20cc70c4d 100644 --- a/metagpt/rag/parser/omniparse/client.py +++ b/metagpt/rag/parser/omniparse/client.py @@ -1,5 +1,7 @@ import mimetypes import os +from pathlib import Path + import aiofiles import httpx from typing import Union @@ -31,37 +33,72 @@ class OmniParseClient: self.parse_website_endpoint = "/parse_website" self.parse_document_endpoint = "/parse_document" - async def __request_parse(self, endpoint: str, files: dict = None, json: dict = None) -> dict: + async def __request_parse( + self, endpoint: str, method: str = "POST", + files: dict = None, params: dict = None, + data: dict = None, json: dict = None, + headers: dict = None, **kwargs, + ) -> dict: """ 请求api解析文档 Args: endpoint (str): API endpoint. files (dict, optional): 请求文件数据. + params (dict, optional): 查询字符串参数. + data (dict, optional): 请求体数据. json (dict, optional): 请求json数据. + headers (dict, optional): 请求头数据. + **kwargs: 其他 httpx.AsyncClient.request() 关键字参数 Returns: dict: 响应的json数据 """ url = f"{self.base_url}{endpoint}" - headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + method = method.upper() + headers = headers or {} + _headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + headers.update(**_headers) async with httpx.AsyncClient() as client: - response = await client.post(url, files=files, json=json, headers=headers, timeout=self.max_timeout) + response = await client.request( + url=url, method=method, + files=files, params=params, json=json, data=data, + headers=headers, timeout=self.max_timeout, **kwargs, + ) response.raise_for_status() return response.json() @staticmethod - def verify_file_ext(filelike: Union[str, bytes], allowed_file_extensions: set): - """校验文件后缀""" - if not filelike or isinstance(filelike, bytes): - return - file_ext = os.path.splitext(filelike)[1].lower() - if file_ext not in allowed_file_extensions: - raise ValueError("File extension must be one of {}".format(allowed_file_extensions)) + def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None): + """ + 校验文件后缀 + Args: + filelike: 文件路径 or 文件字节数据 + allowed_file_extensions: 允许的文件扩展名 + bytes_filename: 当字节数据时使用这个参数校验 + Raises: + ValueError + + Returns: + """ + verify_file_path = None + if isinstance(filelike, (str, Path)): + verify_file_path = str(filelike) + elif isinstance(filelike, bytes) and bytes_filename: + verify_file_path = bytes_filename + + if not verify_file_path: + # 仅传字节数据时不校验 + return + + file_ext = os.path.splitext(verify_file_path)[1].lower() + if file_ext not in allowed_file_extensions: + raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}") + + @staticmethod async def get_file_info( - self, - filelike: Union[str, bytes], + filelike: Union[str, bytes, Path], bytes_filename: str = None, only_bytes=True, ) -> Union[bytes, tuple]: @@ -72,14 +109,17 @@ class OmniParseClient: bytes_filename: 通过字节数据上传需要指定文件名称,方便获取mime_type only_bytes: 是否只需要字节数据 + Raises: + ValueError + Notes: 由于 parse_document 支持多种文件解析,需要上传文件时指定文件的mime_type Returns: [bytes, tuple] """ - if isinstance(filelike, str): - filename = os.path.basename(filelike) + if isinstance(filelike, (str, Path)): + filename = os.path.basename(str(filelike)) async with aiofiles.open(filelike, 'rb') as file: file_bytes = await file.read() @@ -99,7 +139,7 @@ class OmniParseClient: else: raise ValueError("filelike must be a string (file path) or bytes.") - async def parse_document(self, filelike: Union[str, bytes], bytes_filename: str = None) -> OmniParsedResult: + async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult: """ 解析文档类型数据(支持 ".pdf", ".ppt", ".pptx", ".doc", ".docs") Args: @@ -112,13 +152,13 @@ class OmniParseClient: Returns: OmniParsedResult """ - self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS) + self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename) file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False) resp = await self.__request_parse(self.parse_document_endpoint, files={'file': file_info}) data = OmniParsedResult(**resp) return data - async def parse_pdf(self, filelike: Union[str, bytes]) -> OmniParsedResult: + async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult: """ 解析PDF文档 Args: @@ -137,18 +177,21 @@ class OmniParseClient: data = OmniParsedResult(**resp) return data - async def parse_video(self, filelike: Union[str, bytes], bytes_filename: str = None) -> dict: + async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict: """解析视频""" - self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS) + self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename) file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False) return await self.__request_parse(f"{self.parse_media_endpoint}/video", files={'file': file_info}) - async def parse_audio(self, filelike: Union[str, bytes], bytes_filename: str = None) -> dict: + async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict: """解析音频""" - self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS) + self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename) file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False) return await self.__request_parse(f"{self.parse_media_endpoint}/audio", files={'file': file_info}) async def parse_website(self, url: str) -> dict: - """解析网站""" - return await self.__request_parse(self.parse_website_endpoint, json={'url': url}) + """ + 解析网站 + fixme:官方api还存在问题 + """ + return await self.__request_parse(f"{self.parse_website_endpoint}/parse", params={'url': url}) diff --git a/metagpt/rag/parser/omniparse/parse.py b/metagpt/rag/parser/omniparse/parse.py index db61759f5..176620934 100644 --- a/metagpt/rag/parser/omniparse/parse.py +++ b/metagpt/rag/parser/omniparse/parse.py @@ -1,5 +1,6 @@ import asyncio from fileinput import FileInput +from pathlib import Path from typing import List, Union, Optional from llama_index.core import Document @@ -10,6 +11,7 @@ from llama_parse import ResultType from metagpt.rag.parser.omniparse.client import OmniParseClient from metagpt.rag.schema import OmniParseOptions, OmniParseType from metagpt.logs import logger +from metagpt.utils.async_helper import NestAsyncio class OmniParse(BaseReader): @@ -46,7 +48,7 @@ class OmniParse(BaseReader): async def _aload_data( self, - file_path: Union[str, bytes], + file_path: Union[str, bytes, Path], extra_info: Optional[dict] = None, ) -> List[Document]: try: @@ -78,7 +80,7 @@ class OmniParse(BaseReader): extra_info: Optional[dict] = None, ) -> List[Document]: docs = [] - if isinstance(file_path, (str, bytes)): + if isinstance(file_path, (str, bytes, Path)): # 处理单个 docs = await self._aload_data(file_path, extra_info) elif isinstance(file_path, list): @@ -94,4 +96,5 @@ class OmniParse(BaseReader): extra_info: Optional[dict] = None, ) -> List[Document]: """Load data from the input path.""" + NestAsyncio.apply_once() # 兼容异步嵌套调用 return asyncio.run(self.aload_data(file_path, extra_info)) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 7ef191d0c..24c597196 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -229,7 +229,7 @@ class OmniParseOptions(BaseModel): parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型,默认文档类型") max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时") num_workers: int = Field( - default=4, + default=5, gt=0, lt=10, description="多文件列表时并发请求数量",