diff --git a/examples/rag/omniparse.py b/examples/rag/omniparse.py index 7cea8f954..af8207c5a 100644 --- a/examples/rag/omniparse.py +++ b/examples/rag/omniparse.py @@ -3,7 +3,7 @@ import asyncio from metagpt.config2 import config from metagpt.const import EXAMPLE_DATA_PATH from metagpt.logs import logger -from metagpt.rag.parser import OmniParse +from metagpt.rag.parsers import OmniParse from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType from metagpt.utils.omniparse_client import OmniParseClient @@ -19,20 +19,20 @@ async def omniparse_client_example(): # docx with open(TEST_DOCX, "rb") as f: - filelike = f.read() - document_parse_ret = await client.parse_document(filelike=filelike, bytes_filename="test_01.docx") + file_input = f.read() + document_parse_ret = await client.parse_document(file_input=file_input, bytes_filename="test_01.docx") logger.info(document_parse_ret) # pdf - pdf_parse_ret = await client.parse_pdf(filelike=TEST_PDF) + pdf_parse_ret = await client.parse_pdf(file_input=TEST_PDF) logger.info(pdf_parse_ret) # video - video_parse_ret = await client.parse_video(filelike=TEST_VIDEO) + video_parse_ret = await client.parse_video(file_input=TEST_VIDEO) logger.info(video_parse_ret) # audio - audio_parse_ret = await client.parse_audio(filelike=TEST_AUDIO) + audio_parse_ret = await client.parse_audio(file_input=TEST_AUDIO) logger.info(audio_parse_ret) diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index f3f86111d..e015b7b7f 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -38,7 +38,7 @@ from metagpt.rag.factories import ( get_retriever, ) from metagpt.rag.interface import NoEmbedding, RAGObject -from metagpt.rag.parser import OmniParse +from metagpt.rag.parsers import OmniParse from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( @@ -106,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine): if not input_dir and not input_files: raise ValueError("Must provide either `input_dir` or `input_files`.") - file_extractor = cls._get_file_extractor(file_type=".pdf") + file_extractor = cls._get_file_extractor() documents = SimpleDirectoryReader( input_dir=input_dir, input_files=input_files, file_extractor=file_extractor ).load_data() @@ -312,29 +312,21 @@ class SimpleEngine(RetrieverQueryEngine): return [SentenceSplitter()] @staticmethod - def _get_file_extractor(file_type: str = None) -> dict[str:BaseReader]: + def _get_file_extractor() -> dict[str:BaseReader]: """ - Get the file extractor for a specified file type. - If no file type is provided, return all available extractors. - Currently, only OmniParse PDF extraction is supported. - - Args: - file_type: The type of file for which the extractor is needed. Defaults to None. + Get the file extractor. + Currently, only PDF use OmniParse Returns: dict[file_type: BaseReader] """ - file_extractor_mapping: dict[str:BaseReader] = {} + file_extractor: dict[str:BaseReader] = {} if config.omniparse.base_url: pdf_parser = OmniParse( api_key=config.omniparse.api_key, base_url=config.omniparse.base_url, parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD), ) - file_extractor_mapping[".pdf"] = pdf_parser + file_extractor[".pdf"] = pdf_parser - if file_type: - file_extractor = file_extractor_mapping.get(file_type) - return {file_type: file_extractor} if file_extractor else {} - - return file_extractor_mapping + return file_extractor diff --git a/metagpt/rag/parser/__init__.py b/metagpt/rag/parser/__init__.py deleted file mode 100644 index 90d7eb971..000000000 --- a/metagpt/rag/parser/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from metagpt.rag.parser.omniparse import OmniParse - -__all__ = ["OmniParse"] diff --git a/metagpt/rag/parsers/__init__.py b/metagpt/rag/parsers/__init__.py new file mode 100644 index 000000000..03ac0de3a --- /dev/null +++ b/metagpt/rag/parsers/__init__.py @@ -0,0 +1,3 @@ +from metagpt.rag.parsers.omniparse import OmniParse + +__all__ = ["OmniParse"] diff --git a/metagpt/rag/parser/omniparse.py b/metagpt/rag/parsers/omniparse.py similarity index 100% rename from metagpt/rag/parser/omniparse.py rename to metagpt/rag/parsers/omniparse.py diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index b02ef16cc..a8a10f90e 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -232,7 +232,7 @@ class ParseResultType(str, Enum): class OmniParseOptions(BaseModel): - """OmniParse可选配置""" + """OmniParse Options config""" result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type") parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type") diff --git a/metagpt/utils/omniparse_client.py b/metagpt/utils/omniparse_client.py index 787d6fe88..12c5ac392 100644 --- a/metagpt/utils/omniparse_client.py +++ b/metagpt/utils/omniparse_client.py @@ -12,7 +12,7 @@ from metagpt.utils.common import aread_bin class OmniParseClient: """ OmniParse Server Client - This client interacts with the OmniParse server to parse different types of media, documents, and websites. + This client interacts with the OmniParse server to parse different types of media, documents. OmniParse API Documentation: https://docs.cognitivelab.in/api @@ -88,12 +88,12 @@ class OmniParseClient: response.raise_for_status() return response.json() - async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult: + async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult: """ Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx"). Args: - filelike: File path or file byte data. + file_input: File path or file byte data. bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. Raises: @@ -102,18 +102,18 @@ class OmniParseClient: Returns: OmniParsedResult: The result of the document parsing. """ - self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename) - file_info = await self.get_file_info(filelike, bytes_filename) + self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(file_input, bytes_filename) resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info}) data = OmniParsedResult(**resp) return data - async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult: + async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult: """ Parse pdf document. Args: - filelike: File path or file byte data. + file_input: File path or file byte data. Raises: ValueError: If the file extension is not allowed. @@ -121,19 +121,19 @@ class OmniParseClient: Returns: OmniParsedResult: The result of the pdf parsing. """ - self.verify_file_ext(filelike, {".pdf"}) - file_info = await self.get_file_info(filelike, only_bytes=True) + self.verify_file_ext(file_input, {".pdf"}) + file_info = await self.get_file_info(file_input, only_bytes=True) endpoint = f"{self.parse_document_endpoint}/pdf" resp = await self._request_parse(endpoint=endpoint, files={"file": file_info}) data = OmniParsedResult(**resp) return data - async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict: + async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict: """ Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov"). Args: - filelike: File path or file byte data. + file_input: File path or file byte data. bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. Raises: @@ -142,16 +142,16 @@ class OmniParseClient: Returns: dict: JSON response data. """ - self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename) - file_info = await self.get_file_info(filelike, bytes_filename) + self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(file_input, bytes_filename) return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info}) - async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict: + async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict: """ Parse audio-type data (supports ".mp3", ".wav", ".aac"). Args: - filelike: File path or file byte data. + file_input: File path or file byte data. bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. Raises: @@ -160,19 +160,19 @@ class OmniParseClient: Returns: dict: JSON response data. """ - self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename) - file_info = await self.get_file_info(filelike, bytes_filename) + self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(file_input, bytes_filename) return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info}) @staticmethod - def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None): + def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None): """ Verify the file extension. Args: - filelike: File path or file byte data. + file_input: File path or file byte data. allowed_file_extensions: Set of allowed file extensions. - bytes_filename: Filename to use for verification when `filelike` is byte data. + bytes_filename: Filename to use for verification when `file_input` is byte data. Raises: ValueError: If the file extension is not allowed. @@ -180,9 +180,9 @@ class OmniParseClient: Returns: """ verify_file_path = None - if isinstance(filelike, (str, Path)): - verify_file_path = str(filelike) - elif isinstance(filelike, bytes) and bytes_filename: + if isinstance(file_input, (str, Path)): + verify_file_path = str(file_input) + elif isinstance(file_input, bytes) and bytes_filename: verify_file_path = bytes_filename if not verify_file_path: @@ -195,7 +195,7 @@ class OmniParseClient: @staticmethod async def get_file_info( - filelike: Union[str, bytes, Path], + file_input: Union[str, bytes, Path], bytes_filename: str = None, only_bytes: bool = False, ) -> Union[bytes, tuple]: @@ -203,12 +203,12 @@ class OmniParseClient: Get file information. Args: - filelike: File path or file byte data. + file_input: File path or file byte data. bytes_filename: Filename to use when uploading byte data, useful for determining MIME type. only_bytes: Whether to return only byte data. Default is False, which returns a tuple. Raises: - ValueError: If bytes_filename is not provided when filelike is bytes or if filelike is not a valid type. + ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type. Notes: Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types, @@ -217,22 +217,22 @@ class OmniParseClient: Returns: [bytes, tuple] Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type). """ - if isinstance(filelike, (str, Path)): - filename = os.path.basename(str(filelike)) - file_bytes = await aread_bin(filelike) + if isinstance(file_input, (str, Path)): + filename = os.path.basename(str(file_input)) + file_bytes = await aread_bin(file_input) if only_bytes: return file_bytes - mime_type = mimetypes.guess_type(filelike)[0] + mime_type = mimetypes.guess_type(file_input)[0] return filename, file_bytes, mime_type - elif isinstance(filelike, bytes): + elif isinstance(file_input, bytes): if only_bytes: - return filelike + return file_input if not bytes_filename: raise ValueError("bytes_filename must be set when passing bytes") mime_type = mimetypes.guess_type(bytes_filename)[0] - return bytes_filename, filelike, mime_type + return bytes_filename, file_input, mime_type else: - raise ValueError("filelike must be a string (file path) or bytes.") + raise ValueError("file_input must be a string (file path) or bytes.") diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 6bbe011c0..a10fcbe63 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -7,7 +7,7 @@ from llama_index.core.llms import MockLLM from llama_index.core.schema import Document, NodeWithScore, TextNode from metagpt.rag.engines import SimpleEngine -from metagpt.rag.parser import OmniParse +from metagpt.rag.parsers import OmniParse from metagpt.rag.retrievers import SimpleHybridRetriever from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode @@ -319,6 +319,6 @@ class TestSimpleEngine: # mock have omniparse config mock_omniparse_config.base_url = "http://localhost:8000" - file_extractor = SimpleEngine._get_file_extractor(file_type=".pdf") + file_extractor = SimpleEngine._get_file_extractor() assert ".pdf" in file_extractor assert isinstance(file_extractor[".pdf"], OmniParse) diff --git a/tests/metagpt/rag/parser/test_omniparse.py b/tests/metagpt/rag/parser/test_omniparse.py index 57b5329cf..d2b533d06 100644 --- a/tests/metagpt/rag/parser/test_omniparse.py +++ b/tests/metagpt/rag/parser/test_omniparse.py @@ -2,7 +2,7 @@ import pytest from llama_index.core import Document from metagpt.const import EXAMPLE_DATA_PATH -from metagpt.rag.parser import OmniParse +from metagpt.rag.parsers import OmniParse from metagpt.rag.schema import ( OmniParsedResult, OmniParseOptions, @@ -23,7 +23,7 @@ class TestOmniParseClient: @pytest.fixture def mock_request_parse(self, mocker): - return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse") + return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse") @pytest.mark.asyncio async def test_parse_pdf(self, mock_request_parse): @@ -91,7 +91,7 @@ class TestOmniParse: @pytest.fixture def mock_request_parse(self, mocker): - return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse") + return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse") @pytest.mark.asyncio async def test_load_data(self, mock_omniparse, mock_request_parse):