代码优化

This commit is contained in:
liuminhui 2024-07-22 18:38:44 +08:00
parent 758acf8ba6
commit f9d3a8c521
9 changed files with 57 additions and 65 deletions

View file

@ -3,7 +3,7 @@ import asyncio
from metagpt.config2 import config
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.logs import logger
from metagpt.rag.parser import OmniParse
from metagpt.rag.parsers import OmniParse
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
from metagpt.utils.omniparse_client import OmniParseClient
@ -19,20 +19,20 @@ async def omniparse_client_example():
# docx
with open(TEST_DOCX, "rb") as f:
filelike = f.read()
document_parse_ret = await client.parse_document(filelike=filelike, bytes_filename="test_01.docx")
file_input = f.read()
document_parse_ret = await client.parse_document(file_input=file_input, bytes_filename="test_01.docx")
logger.info(document_parse_ret)
# pdf
pdf_parse_ret = await client.parse_pdf(filelike=TEST_PDF)
pdf_parse_ret = await client.parse_pdf(file_input=TEST_PDF)
logger.info(pdf_parse_ret)
# video
video_parse_ret = await client.parse_video(filelike=TEST_VIDEO)
video_parse_ret = await client.parse_video(file_input=TEST_VIDEO)
logger.info(video_parse_ret)
# audio
audio_parse_ret = await client.parse_audio(filelike=TEST_AUDIO)
audio_parse_ret = await client.parse_audio(file_input=TEST_AUDIO)
logger.info(audio_parse_ret)

View file

@ -38,7 +38,7 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.parser import OmniParse
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
@ -106,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine):
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
file_extractor = cls._get_file_extractor(file_type=".pdf")
file_extractor = cls._get_file_extractor()
documents = SimpleDirectoryReader(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
).load_data()
@ -312,29 +312,21 @@ class SimpleEngine(RetrieverQueryEngine):
return [SentenceSplitter()]
@staticmethod
def _get_file_extractor(file_type: str = None) -> dict[str:BaseReader]:
def _get_file_extractor() -> dict[str:BaseReader]:
"""
Get the file extractor for a specified file type.
If no file type is provided, return all available extractors.
Currently, only OmniParse PDF extraction is supported.
Args:
file_type: The type of file for which the extractor is needed. Defaults to None.
Get the file extractor.
Currently, only PDF files are parsed with OmniParse.
Returns:
dict[file_type: BaseReader]
"""
file_extractor_mapping: dict[str:BaseReader] = {}
file_extractor: dict[str:BaseReader] = {}
if config.omniparse.base_url:
pdf_parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
)
file_extractor_mapping[".pdf"] = pdf_parser
file_extractor[".pdf"] = pdf_parser
if file_type:
file_extractor = file_extractor_mapping.get(file_type)
return {file_type: file_extractor} if file_extractor else {}
return file_extractor_mapping
return file_extractor

View file

@ -1,3 +0,0 @@
from metagpt.rag.parser.omniparse import OmniParse
__all__ = ["OmniParse"]

View file

@ -0,0 +1,3 @@
from metagpt.rag.parsers.omniparse import OmniParse
__all__ = ["OmniParse"]

View file

@ -232,7 +232,7 @@ class ParseResultType(str, Enum):
class OmniParseOptions(BaseModel):
"""OmniParse可选配置"""
"""OmniParse Options config"""
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")

View file

@ -12,7 +12,7 @@ from metagpt.utils.common import aread_bin
class OmniParseClient:
"""
OmniParse Server Client
This client interacts with the OmniParse server to parse different types of media, documents, and websites.
This client interacts with the OmniParse server to parse different types of media and documents.
OmniParse API Documentation: https://docs.cognitivelab.in/api
@ -88,12 +88,12 @@ class OmniParseClient:
response.raise_for_status()
return response.json()
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
Args:
filelike: File path or file byte data.
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
@ -102,18 +102,18 @@ class OmniParseClient:
Returns:
OmniParsedResult: The result of the document parsing.
"""
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename)
self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult:
"""
Parse pdf document.
Args:
filelike: File path or file byte data.
file_input: File path or file byte data.
Raises:
ValueError: If the file extension is not allowed.
@ -121,19 +121,19 @@ class OmniParseClient:
Returns:
OmniParsedResult: The result of the pdf parsing.
"""
self.verify_file_ext(filelike, {".pdf"})
file_info = await self.get_file_info(filelike, only_bytes=True)
self.verify_file_ext(file_input, {".pdf"})
file_info = await self.get_file_info(file_input, only_bytes=True)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
Args:
filelike: File path or file byte data.
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
@ -142,16 +142,16 @@ class OmniParseClient:
Returns:
dict: JSON response data.
"""
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename)
self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse audio-type data (supports ".mp3", ".wav", ".aac").
Args:
filelike: File path or file byte data.
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
@ -160,19 +160,19 @@ class OmniParseClient:
Returns:
dict: JSON response data.
"""
self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename)
self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
@staticmethod
def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
"""
Verify the file extension.
Args:
filelike: File path or file byte data.
file_input: File path or file byte data.
allowed_file_extensions: Set of allowed file extensions.
bytes_filename: Filename to use for verification when `filelike` is byte data.
bytes_filename: Filename to use for verification when `file_input` is byte data.
Raises:
ValueError: If the file extension is not allowed.
@ -180,9 +180,9 @@ class OmniParseClient:
Returns:
"""
verify_file_path = None
if isinstance(filelike, (str, Path)):
verify_file_path = str(filelike)
elif isinstance(filelike, bytes) and bytes_filename:
if isinstance(file_input, (str, Path)):
verify_file_path = str(file_input)
elif isinstance(file_input, bytes) and bytes_filename:
verify_file_path = bytes_filename
if not verify_file_path:
@ -195,7 +195,7 @@ class OmniParseClient:
@staticmethod
async def get_file_info(
filelike: Union[str, bytes, Path],
file_input: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes: bool = False,
) -> Union[bytes, tuple]:
@ -203,12 +203,12 @@ class OmniParseClient:
Get file information.
Args:
filelike: File path or file byte data.
file_input: File path or file byte data.
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
Raises:
ValueError: If bytes_filename is not provided when filelike is bytes or if filelike is not a valid type.
ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type.
Notes:
Since `parse_document`, `parse_video`, and `parse_audio` support parsing various file types,
@ -217,22 +217,22 @@ class OmniParseClient:
Returns: [bytes, tuple]
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
"""
if isinstance(filelike, (str, Path)):
filename = os.path.basename(str(filelike))
file_bytes = await aread_bin(filelike)
if isinstance(file_input, (str, Path)):
filename = os.path.basename(str(file_input))
file_bytes = await aread_bin(file_input)
if only_bytes:
return file_bytes
mime_type = mimetypes.guess_type(filelike)[0]
mime_type = mimetypes.guess_type(file_input)[0]
return filename, file_bytes, mime_type
elif isinstance(filelike, bytes):
elif isinstance(file_input, bytes):
if only_bytes:
return filelike
return file_input
if not bytes_filename:
raise ValueError("bytes_filename must be set when passing bytes")
mime_type = mimetypes.guess_type(bytes_filename)[0]
return bytes_filename, filelike, mime_type
return bytes_filename, file_input, mime_type
else:
raise ValueError("filelike must be a string (file path) or bytes.")
raise ValueError("file_input must be a string (file path) or bytes.")

View file

@ -7,7 +7,7 @@ from llama_index.core.llms import MockLLM
from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.parser import OmniParse
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers import SimpleHybridRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
@ -319,6 +319,6 @@ class TestSimpleEngine:
# mock have omniparse config
mock_omniparse_config.base_url = "http://localhost:8000"
file_extractor = SimpleEngine._get_file_extractor(file_type=".pdf")
file_extractor = SimpleEngine._get_file_extractor()
assert ".pdf" in file_extractor
assert isinstance(file_extractor[".pdf"], OmniParse)

View file

@ -2,7 +2,7 @@ import pytest
from llama_index.core import Document
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.parser import OmniParse
from metagpt.rag.parsers import OmniParse
from metagpt.rag.schema import (
OmniParsedResult,
OmniParseOptions,
@ -23,7 +23,7 @@ class TestOmniParseClient:
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_parse_pdf(self, mock_request_parse):
@ -91,7 +91,7 @@ class TestOmniParse:
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_load_data(self, mock_omniparse, mock_request_parse):