mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 19:36:24 +02:00
代码优化
This commit is contained in:
parent
758acf8ba6
commit
f9d3a8c521
9 changed files with 57 additions and 65 deletions
|
|
@ -3,7 +3,7 @@ import asyncio
|
|||
from metagpt.config2 import config
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.parser import OmniParse
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
|
||||
from metagpt.utils.omniparse_client import OmniParseClient
|
||||
|
||||
|
|
@ -19,20 +19,20 @@ async def omniparse_client_example():
|
|||
|
||||
# docx
|
||||
with open(TEST_DOCX, "rb") as f:
|
||||
filelike = f.read()
|
||||
document_parse_ret = await client.parse_document(filelike=filelike, bytes_filename="test_01.docx")
|
||||
file_input = f.read()
|
||||
document_parse_ret = await client.parse_document(file_input=file_input, bytes_filename="test_01.docx")
|
||||
logger.info(document_parse_ret)
|
||||
|
||||
# pdf
|
||||
pdf_parse_ret = await client.parse_pdf(filelike=TEST_PDF)
|
||||
pdf_parse_ret = await client.parse_pdf(file_input=TEST_PDF)
|
||||
logger.info(pdf_parse_ret)
|
||||
|
||||
# video
|
||||
video_parse_ret = await client.parse_video(filelike=TEST_VIDEO)
|
||||
video_parse_ret = await client.parse_video(file_input=TEST_VIDEO)
|
||||
logger.info(video_parse_ret)
|
||||
|
||||
# audio
|
||||
audio_parse_ret = await client.parse_audio(filelike=TEST_AUDIO)
|
||||
audio_parse_ret = await client.parse_audio(file_input=TEST_AUDIO)
|
||||
logger.info(audio_parse_ret)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ from metagpt.rag.factories import (
|
|||
get_retriever,
|
||||
)
|
||||
from metagpt.rag.interface import NoEmbedding, RAGObject
|
||||
from metagpt.rag.parser import OmniParse
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
|
|
@ -106,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
if not input_dir and not input_files:
|
||||
raise ValueError("Must provide either `input_dir` or `input_files`.")
|
||||
|
||||
file_extractor = cls._get_file_extractor(file_type=".pdf")
|
||||
file_extractor = cls._get_file_extractor()
|
||||
documents = SimpleDirectoryReader(
|
||||
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
|
||||
).load_data()
|
||||
|
|
@ -312,29 +312,21 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return [SentenceSplitter()]
|
||||
|
||||
@staticmethod
|
||||
def _get_file_extractor(file_type: str = None) -> dict[str:BaseReader]:
|
||||
def _get_file_extractor() -> dict[str:BaseReader]:
|
||||
"""
|
||||
Get the file extractor for a specified file type.
|
||||
If no file type is provided, return all available extractors.
|
||||
Currently, only OmniParse PDF extraction is supported.
|
||||
|
||||
Args:
|
||||
file_type: The type of file for which the extractor is needed. Defaults to None.
|
||||
Get the file extractor.
|
||||
Currently, only PDF files use OmniParse.
|
||||
|
||||
Returns:
|
||||
dict[file_type: BaseReader]
|
||||
"""
|
||||
file_extractor_mapping: dict[str:BaseReader] = {}
|
||||
file_extractor: dict[str:BaseReader] = {}
|
||||
if config.omniparse.base_url:
|
||||
pdf_parser = OmniParse(
|
||||
api_key=config.omniparse.api_key,
|
||||
base_url=config.omniparse.base_url,
|
||||
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
|
||||
)
|
||||
file_extractor_mapping[".pdf"] = pdf_parser
|
||||
file_extractor[".pdf"] = pdf_parser
|
||||
|
||||
if file_type:
|
||||
file_extractor = file_extractor_mapping.get(file_type)
|
||||
return {file_type: file_extractor} if file_extractor else {}
|
||||
|
||||
return file_extractor_mapping
|
||||
return file_extractor
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
from metagpt.rag.parser.omniparse import OmniParse
|
||||
|
||||
__all__ = ["OmniParse"]
|
||||
3
metagpt/rag/parsers/__init__.py
Normal file
3
metagpt/rag/parsers/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from metagpt.rag.parsers.omniparse import OmniParse
|
||||
|
||||
__all__ = ["OmniParse"]
|
||||
|
|
@ -232,7 +232,7 @@ class ParseResultType(str, Enum):
|
|||
|
||||
|
||||
class OmniParseOptions(BaseModel):
|
||||
"""OmniParse可选配置"""
|
||||
"""OmniParse Options config"""
|
||||
|
||||
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
|
||||
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from metagpt.utils.common import aread_bin
|
|||
class OmniParseClient:
|
||||
"""
|
||||
OmniParse Server Client
|
||||
This client interacts with the OmniParse server to parse different types of media, documents, and websites.
|
||||
This client interacts with the OmniParse server to parse different types of media and documents.
|
||||
|
||||
OmniParse API Documentation: https://docs.cognitivelab.in/api
|
||||
|
||||
|
|
@ -88,12 +88,12 @@ class OmniParseClient:
|
|||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
|
||||
async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
|
||||
"""
|
||||
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
|
|
@ -102,18 +102,18 @@ class OmniParseClient:
|
|||
Returns:
|
||||
OmniParsedResult: The result of the document parsing.
|
||||
"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename)
|
||||
self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(file_input, bytes_filename)
|
||||
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
|
||||
async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult:
|
||||
"""
|
||||
Parse pdf document.
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
file_input: File path or file byte data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
|
@ -121,19 +121,19 @@ class OmniParseClient:
|
|||
Returns:
|
||||
OmniParsedResult: The result of the pdf parsing.
|
||||
"""
|
||||
self.verify_file_ext(filelike, {".pdf"})
|
||||
file_info = await self.get_file_info(filelike, only_bytes=True)
|
||||
self.verify_file_ext(file_input, {".pdf"})
|
||||
file_info = await self.get_file_info(file_input, only_bytes=True)
|
||||
endpoint = f"{self.parse_document_endpoint}/pdf"
|
||||
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""
|
||||
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
|
|
@ -142,16 +142,16 @@ class OmniParseClient:
|
|||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename)
|
||||
self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(file_input, bytes_filename)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
|
||||
|
||||
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""
|
||||
Parse audio-type data (supports ".mp3", ".wav", ".aac").
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
|
|
@ -160,19 +160,19 @@ class OmniParseClient:
|
|||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename)
|
||||
self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(file_input, bytes_filename)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
|
||||
|
||||
@staticmethod
|
||||
def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
|
||||
def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
|
||||
"""
|
||||
Verify the file extension.
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
file_input: File path or file byte data.
|
||||
allowed_file_extensions: Set of allowed file extensions.
|
||||
bytes_filename: Filename to use for verification when `filelike` is byte data.
|
||||
bytes_filename: Filename to use for verification when `file_input` is byte data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
|
@ -180,9 +180,9 @@ class OmniParseClient:
|
|||
Returns:
|
||||
"""
|
||||
verify_file_path = None
|
||||
if isinstance(filelike, (str, Path)):
|
||||
verify_file_path = str(filelike)
|
||||
elif isinstance(filelike, bytes) and bytes_filename:
|
||||
if isinstance(file_input, (str, Path)):
|
||||
verify_file_path = str(file_input)
|
||||
elif isinstance(file_input, bytes) and bytes_filename:
|
||||
verify_file_path = bytes_filename
|
||||
|
||||
if not verify_file_path:
|
||||
|
|
@ -195,7 +195,7 @@ class OmniParseClient:
|
|||
|
||||
@staticmethod
|
||||
async def get_file_info(
|
||||
filelike: Union[str, bytes, Path],
|
||||
file_input: Union[str, bytes, Path],
|
||||
bytes_filename: str = None,
|
||||
only_bytes: bool = False,
|
||||
) -> Union[bytes, tuple]:
|
||||
|
|
@ -203,12 +203,12 @@ class OmniParseClient:
|
|||
Get file information.
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
file_input: File path or file byte data.
|
||||
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
|
||||
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
|
||||
|
||||
Raises:
|
||||
ValueError: If bytes_filename is not provided when filelike is bytes or if filelike is not a valid type.
|
||||
ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type.
|
||||
|
||||
Notes:
|
||||
Since `parse_document`, `parse_video`, and `parse_audio` support parsing various file types,
|
||||
|
|
@ -217,22 +217,22 @@ class OmniParseClient:
|
|||
Returns: [bytes, tuple]
|
||||
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
|
||||
"""
|
||||
if isinstance(filelike, (str, Path)):
|
||||
filename = os.path.basename(str(filelike))
|
||||
file_bytes = await aread_bin(filelike)
|
||||
if isinstance(file_input, (str, Path)):
|
||||
filename = os.path.basename(str(file_input))
|
||||
file_bytes = await aread_bin(file_input)
|
||||
|
||||
if only_bytes:
|
||||
return file_bytes
|
||||
|
||||
mime_type = mimetypes.guess_type(filelike)[0]
|
||||
mime_type = mimetypes.guess_type(file_input)[0]
|
||||
return filename, file_bytes, mime_type
|
||||
elif isinstance(filelike, bytes):
|
||||
elif isinstance(file_input, bytes):
|
||||
if only_bytes:
|
||||
return filelike
|
||||
return file_input
|
||||
if not bytes_filename:
|
||||
raise ValueError("bytes_filename must be set when passing bytes")
|
||||
|
||||
mime_type = mimetypes.guess_type(bytes_filename)[0]
|
||||
return bytes_filename, filelike, mime_type
|
||||
return bytes_filename, file_input, mime_type
|
||||
else:
|
||||
raise ValueError("filelike must be a string (file path) or bytes.")
|
||||
raise ValueError("file_input must be a string (file path) or bytes.")
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from llama_index.core.llms import MockLLM
|
|||
from llama_index.core.schema import Document, NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.parser import OmniParse
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
|
||||
|
|
@ -319,6 +319,6 @@ class TestSimpleEngine:
|
|||
|
||||
# mock the case where an omniparse config is present
|
||||
mock_omniparse_config.base_url = "http://localhost:8000"
|
||||
file_extractor = SimpleEngine._get_file_extractor(file_type=".pdf")
|
||||
file_extractor = SimpleEngine._get_file_extractor()
|
||||
assert ".pdf" in file_extractor
|
||||
assert isinstance(file_extractor[".pdf"], OmniParse)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
from llama_index.core import Document
|
||||
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.rag.parser import OmniParse
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.schema import (
|
||||
OmniParsedResult,
|
||||
OmniParseOptions,
|
||||
|
|
@ -23,7 +23,7 @@ class TestOmniParseClient:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
|
||||
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_pdf(self, mock_request_parse):
|
||||
|
|
@ -91,7 +91,7 @@ class TestOmniParse:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
|
||||
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_data(self, mock_omniparse, mock_request_parse):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue