mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
cr修改,单测完善
This commit is contained in:
parent
79334de5a4
commit
758acf8ba6
18 changed files with 371 additions and 178 deletions
|
|
@ -59,3 +59,7 @@ iflytek_api_key: "YOUR_API_KEY"
|
|||
iflytek_api_secret: "YOUR_API_SECRET"
|
||||
|
||||
metagpt_tti_url: "YOUR_MODEL_URL"
|
||||
|
||||
omniparse:
|
||||
api_key: "YOUR_API_KEY"
|
||||
base_url: "YOUR_BASE_URL"
|
||||
|
|
|
|||
|
|
@ -5,8 +5,4 @@ llm:
|
|||
api_type: "openai" # or azure / ollama / groq etc.
|
||||
model: "gpt-4-turbo" # or gpt-3.5-turbo
|
||||
base_url: "https://api.openai.com/v1" # or forward url / other llm url
|
||||
api_key: "xxxx"
|
||||
|
||||
omniparse:
|
||||
api_key: "your_api_key"
|
||||
base_url: "http://192.168.50.126:8000"
|
||||
api_key: "YOUR_API_KEY"
|
||||
|
|
@ -1,18 +1,16 @@
|
|||
import asyncio
|
||||
|
||||
from llama_parse import ResultType
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.parser.omniparse.client import OmniParseClient
|
||||
from metagpt.rag.parser.omniparse.parse import OmniParse
|
||||
from metagpt.rag.schema import OmniParseOptions, OmniParseType
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.parser import OmniParse
|
||||
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
|
||||
from metagpt.utils.omniparse_client import OmniParseClient
|
||||
|
||||
TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx"
|
||||
TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf"
|
||||
TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4"
|
||||
TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3"
|
||||
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
|
||||
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
|
||||
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
|
||||
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
|
||||
TEST_WEBSITE_URL = "https://github.com/geekan/MetaGPT"
|
||||
|
||||
|
||||
|
|
@ -37,10 +35,6 @@ async def omniparse_client_example():
|
|||
audio_parse_ret = await client.parse_audio(filelike=TEST_AUDIO)
|
||||
logger.info(audio_parse_ret)
|
||||
|
||||
# website fixme:omniparse官方api还存在问题
|
||||
# website_parse_ret = await client.parse_website(url=TEST_WEBSITE_URL)
|
||||
# logger.info(website_parse_ret)
|
||||
|
||||
|
||||
async def omniparse_example():
|
||||
parser = OmniParse(
|
||||
|
|
@ -48,10 +42,10 @@ async def omniparse_example():
|
|||
base_url=config.omniparse.base_url,
|
||||
parse_options=OmniParseOptions(
|
||||
parse_type=OmniParseType.PDF,
|
||||
result_type=ResultType.MD,
|
||||
result_type=ParseResultType.MD,
|
||||
max_timeout=120,
|
||||
num_workers=3,
|
||||
)
|
||||
),
|
||||
)
|
||||
ret = parser.load_data(file_path=TEST_PDF)
|
||||
logger.info(ret)
|
||||
|
|
@ -67,5 +61,5 @@ async def main():
|
|||
await omniparse_example()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -12,7 +12,8 @@ from typing import Dict, Iterable, List, Literal, Optional
|
|||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.embedding_config import EmbeddingConfig, OmniParseConfig
|
||||
from metagpt.configs.embedding_config import EmbeddingConfig
|
||||
from metagpt.configs.file_parser_config import OmniParseConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
|
|
|
|||
|
|
@ -52,8 +52,3 @@ class EmbeddingConfig(YamlModel):
|
|||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
class OmniParseConfig(YamlModel):
|
||||
api_key: str = ""
|
||||
base_url: str = ""
|
||||
|
|
|
|||
6
metagpt/configs/file_parser_config.py
Normal file
6
metagpt/configs/file_parser_config.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class OmniParseConfig(YamlModel):
|
||||
api_key: str = ""
|
||||
base_url: str = ""
|
||||
|
|
@ -14,6 +14,7 @@ from llama_index.core.llms import LLM
|
|||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.response_synthesizers import (
|
||||
BaseSynthesizer,
|
||||
get_response_synthesizer,
|
||||
|
|
@ -27,7 +28,6 @@ from llama_index.core.schema import (
|
|||
QueryType,
|
||||
TransformComponent,
|
||||
)
|
||||
from llama_parse import ResultType
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.rag.factories import (
|
||||
|
|
@ -38,7 +38,7 @@ from metagpt.rag.factories import (
|
|||
get_retriever,
|
||||
)
|
||||
from metagpt.rag.interface import NoEmbedding, RAGObject
|
||||
from metagpt.rag.parser.omniparse.parse import OmniParse
|
||||
from metagpt.rag.parser import OmniParse
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
|
|
@ -46,7 +46,10 @@ from metagpt.rag.schema import (
|
|||
BaseRankerConfig,
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
ObjectNode, OmniParseOptions, OmniParseType,
|
||||
ObjectNode,
|
||||
OmniParseOptions,
|
||||
OmniParseType,
|
||||
ParseResultType,
|
||||
)
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
|
@ -76,18 +79,6 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
)
|
||||
self._transformations = transformations or self._default_transformations()
|
||||
|
||||
@classmethod
|
||||
def get_file_extractor(cls, file_type: str):
|
||||
if not config.omniparse.base_url:
|
||||
return
|
||||
parser = OmniParse(
|
||||
api_key=config.omniparse.api_key,
|
||||
base_url=config.omniparse.base_url,
|
||||
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ResultType.MD)
|
||||
)
|
||||
file_extractor = {file_type: parser}
|
||||
return file_extractor
|
||||
|
||||
@classmethod
|
||||
def from_docs(
|
||||
cls,
|
||||
|
|
@ -115,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
if not input_dir and not input_files:
|
||||
raise ValueError("Must provide either `input_dir` or `input_files`.")
|
||||
|
||||
file_extractor = cls.get_file_extractor(file_type=".pdf")
|
||||
file_extractor = cls._get_file_extractor(file_type=".pdf")
|
||||
documents = SimpleDirectoryReader(
|
||||
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
|
||||
).load_data()
|
||||
|
|
@ -319,3 +310,31 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
@staticmethod
|
||||
def _default_transformations():
|
||||
return [SentenceSplitter()]
|
||||
|
||||
@staticmethod
|
||||
def _get_file_extractor(file_type: str = None) -> dict[str:BaseReader]:
|
||||
"""
|
||||
Get the file extractor for a specified file type.
|
||||
If no file type is provided, return all available extractors.
|
||||
Currently, only OmniParse PDF extraction is supported.
|
||||
|
||||
Args:
|
||||
file_type: The type of file for which the extractor is needed. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict[file_type: BaseReader]
|
||||
"""
|
||||
file_extractor_mapping: dict[str:BaseReader] = {}
|
||||
if config.omniparse.base_url:
|
||||
pdf_parser = OmniParse(
|
||||
api_key=config.omniparse.api_key,
|
||||
base_url=config.omniparse.base_url,
|
||||
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
|
||||
)
|
||||
file_extractor_mapping[".pdf"] = pdf_parser
|
||||
|
||||
if file_type:
|
||||
file_extractor = file_extractor_mapping.get(file_type)
|
||||
return {file_type: file_extractor} if file_extractor else {}
|
||||
|
||||
return file_extractor_mapping
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from metagpt.rag.parser.omniparse import OmniParse
|
||||
|
||||
__all__ = ["OmniParse"]
|
||||
|
|
@ -1,28 +1,31 @@
|
|||
import asyncio
|
||||
from fileinput import FileInput
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.async_utils import run_jobs
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_parse import ResultType
|
||||
|
||||
from metagpt.rag.parser.omniparse.client import OmniParseClient
|
||||
from metagpt.rag.schema import OmniParseOptions, OmniParseType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.schema import OmniParseOptions, OmniParseType
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.omniparse_client import OmniParseClient
|
||||
|
||||
|
||||
class OmniParse(BaseReader):
|
||||
"""OmniParse"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key=None,
|
||||
base_url="http://localhost:8000",
|
||||
parse_options: OmniParseOptions = None
|
||||
self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
api_key: Default None, can be used for authentication later.
|
||||
base_url: OmniParse Base URL for the API.
|
||||
parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values.
|
||||
"""
|
||||
self.parse_options = parse_options or OmniParseOptions()
|
||||
self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout)
|
||||
|
||||
|
|
@ -47,20 +50,32 @@ class OmniParse(BaseReader):
|
|||
self.parse_options.result_type = result_type
|
||||
|
||||
async def _aload_data(
|
||||
self,
|
||||
file_path: Union[str, bytes, Path],
|
||||
extra_info: Optional[dict] = None,
|
||||
self,
|
||||
file_path: Union[str, bytes, Path],
|
||||
extra_info: Optional[dict] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Load data from the input file_path.
|
||||
|
||||
Args:
|
||||
file_path: File path or file byte data.
|
||||
extra_info: Optional dictionary containing additional information.
|
||||
|
||||
Returns:
|
||||
List[Document]
|
||||
"""
|
||||
try:
|
||||
if self.parse_type == OmniParseType.PDF:
|
||||
# 目前先只支持pdf解析
|
||||
# pdf parse
|
||||
parsed_result = await self.omniparse_client.parse_pdf(file_path)
|
||||
else:
|
||||
# other parse use omniparse_client.parse_document
|
||||
# For compatible byte data, additional filename is required
|
||||
extra_info = extra_info or {}
|
||||
filename = extra_info.get("filename") # 兼容字节数据要额外传filename
|
||||
filename = extra_info.get("filename")
|
||||
parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename)
|
||||
|
||||
# 获取指定的结构数据
|
||||
# Get the specified structured data based on result_type
|
||||
content = getattr(parsed_result, self.result_type)
|
||||
docs = [
|
||||
Document(
|
||||
|
|
@ -75,26 +90,51 @@ class OmniParse(BaseReader):
|
|||
return docs
|
||||
|
||||
async def aload_data(
|
||||
self,
|
||||
file_path: Union[List[FileInput], FileInput],
|
||||
extra_info: Optional[dict] = None,
|
||||
self,
|
||||
file_path: Union[List[FileInput], FileInput],
|
||||
extra_info: Optional[dict] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Load data from the input file_path.
|
||||
|
||||
Args:
|
||||
file_path: File path or file byte data.
|
||||
extra_info: Optional dictionary containing additional information.
|
||||
|
||||
Notes:
|
||||
This method ultimately calls _aload_data for processing.
|
||||
|
||||
Returns:
|
||||
List[Document]
|
||||
"""
|
||||
docs = []
|
||||
if isinstance(file_path, (str, bytes, Path)):
|
||||
# 处理单个
|
||||
# Processing single file
|
||||
docs = await self._aload_data(file_path, extra_info)
|
||||
elif isinstance(file_path, list):
|
||||
# 并发处理多个
|
||||
# Concurrently process multiple files
|
||||
parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path]
|
||||
doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers)
|
||||
docs = [doc for docs in doc_ret_list for doc in docs]
|
||||
return docs
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
file_path: Union[List[FileInput], FileInput],
|
||||
extra_info: Optional[dict] = None,
|
||||
self,
|
||||
file_path: Union[List[FileInput], FileInput],
|
||||
extra_info: Optional[dict] = None,
|
||||
) -> List[Document]:
|
||||
"""Load data from the input path."""
|
||||
NestAsyncio.apply_once() # 兼容异步嵌套调用
|
||||
"""
|
||||
Load data from the input file_path.
|
||||
|
||||
Args:
|
||||
file_path: File path or file byte data.
|
||||
extra_info: Optional dictionary containing additional information.
|
||||
|
||||
Notes:
|
||||
This method ultimately calls aload_data for processing.
|
||||
|
||||
Returns:
|
||||
List[Document]
|
||||
"""
|
||||
NestAsyncio.apply_once() # Ensure compatibility with nested async calls
|
||||
return asyncio.run(self.aload_data(file_path, extra_info))
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
from .client import OmniParseClient
|
||||
from .parse import OmniParse
|
||||
|
|
@ -8,7 +8,6 @@ from llama_index.core.embeddings import BaseEmbedding
|
|||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from llama_parse import ResultType
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
|
@ -218,23 +217,31 @@ class ObjectNode(TextNode):
|
|||
|
||||
|
||||
class OmniParseType(str, Enum):
|
||||
"""OmniParse解析类型"""
|
||||
"""OmniParseType"""
|
||||
|
||||
PDF = "PDF"
|
||||
DOCUMENT = "DOCUMENT"
|
||||
|
||||
|
||||
class ParseResultType(str, Enum):
|
||||
"""The result type for the parser."""
|
||||
|
||||
TXT = "text"
|
||||
MD = "markdown"
|
||||
JSON = "json"
|
||||
|
||||
|
||||
class OmniParseOptions(BaseModel):
|
||||
"""OmniParse可选配置"""
|
||||
|
||||
result_type: ResultType = Field(default=ResultType.MD, description="OmniParse解析返回的结果类型")
|
||||
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型,默认文档类型")
|
||||
max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时")
|
||||
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
|
||||
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")
|
||||
max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests")
|
||||
num_workers: int = Field(
|
||||
default=5,
|
||||
gt=0,
|
||||
lt=10,
|
||||
description="多文件列表时并发请求数量",
|
||||
description="Number of concurrent requests for multiple files",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,28 +3,35 @@ import os
|
|||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
|
||||
from metagpt.rag.schema import OmniParsedResult
|
||||
from metagpt.utils.common import aread_bin
|
||||
|
||||
|
||||
class OmniParseClient:
|
||||
"""
|
||||
OmniParse Server Client
|
||||
OmniParse API Docs: https://docs.cognitivelab.in/api
|
||||
This client interacts with the OmniParse server to parse different types of media, documents, and websites.
|
||||
|
||||
OmniParse API Documentation: https://docs.cognitivelab.in/api
|
||||
|
||||
Attributes:
|
||||
ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
|
||||
ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
|
||||
ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
|
||||
"""
|
||||
|
||||
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
|
||||
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
|
||||
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
|
||||
|
||||
def __init__(self, api_key=None, base_url="http://localhost:8000", max_timeout=120):
|
||||
def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
|
||||
"""
|
||||
Args:
|
||||
api_key: 默认 None 后续可用于鉴权
|
||||
base_url: api 基础url
|
||||
max_timeout: 请求最大超时时间单位s
|
||||
api_key: Default None, can be used for authentication later.
|
||||
base_url: Base URL for the API.
|
||||
max_timeout: Maximum request timeout in seconds.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
|
@ -46,19 +53,20 @@ class OmniParseClient:
|
|||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
请求api解析文档
|
||||
Request OmniParse API to parse a document.
|
||||
|
||||
Args:
|
||||
endpoint (str): API endpoint.
|
||||
files (dict, optional): 请求文件数据.
|
||||
params (dict, optional): 查询字符串参数.
|
||||
data (dict, optional): 请求体数据.
|
||||
json (dict, optional): 请求json数据.
|
||||
headers (dict, optional): 请求头数据.
|
||||
**kwargs: 其他 httpx.AsyncClient.request() 关键字参数
|
||||
method (str, optional): HTTP method to use. Default is "POST".
|
||||
files (dict, optional): Files to include in the request.
|
||||
params (dict, optional): Query string parameters.
|
||||
data (dict, optional): Form data to include in the request body.
|
||||
json (dict, optional): JSON data to include in the request body.
|
||||
headers (dict, optional): HTTP headers to include in the request.
|
||||
**kwargs: Additional keyword arguments for httpx.AsyncClient.request()
|
||||
|
||||
Returns:
|
||||
dict: 响应的json数据
|
||||
dict: JSON response data.
|
||||
"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
method = method.upper()
|
||||
|
|
@ -80,17 +88,94 @@ class OmniParseClient:
|
|||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
|
||||
"""
|
||||
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
OmniParsedResult: The result of the document parsing.
|
||||
"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename)
|
||||
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
|
||||
"""
|
||||
Parse pdf document.
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
OmniParsedResult: The result of the pdf parsing.
|
||||
"""
|
||||
self.verify_file_ext(filelike, {".pdf"})
|
||||
file_info = await self.get_file_info(filelike, only_bytes=True)
|
||||
endpoint = f"{self.parse_document_endpoint}/pdf"
|
||||
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""
|
||||
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
|
||||
|
||||
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""
|
||||
Parse audio-type data (supports ".mp3", ".wav", ".aac").
|
||||
|
||||
Args:
|
||||
filelike: File path or file byte data.
|
||||
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
dict: JSON response data.
|
||||
"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
|
||||
|
||||
@staticmethod
|
||||
def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
|
||||
"""
|
||||
校验文件后缀
|
||||
Verify the file extension.
|
||||
|
||||
Args:
|
||||
filelike: 文件路径 or 文件字节数据
|
||||
allowed_file_extensions: 允许的文件扩展名
|
||||
bytes_filename: 当字节数据时使用这个参数校验
|
||||
filelike: File path or file byte data.
|
||||
allowed_file_extensions: Set of allowed file extensions.
|
||||
bytes_filename: Filename to use for verification when `filelike` is byte data.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
ValueError: If the file extension is not allowed.
|
||||
|
||||
Returns:
|
||||
"""
|
||||
|
|
@ -101,7 +186,7 @@ class OmniParseClient:
|
|||
verify_file_path = bytes_filename
|
||||
|
||||
if not verify_file_path:
|
||||
# 仅传字节数据时不校验
|
||||
# Do not verify if only byte data is provided
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(verify_file_path)[1].lower()
|
||||
|
|
@ -112,28 +197,29 @@ class OmniParseClient:
|
|||
async def get_file_info(
|
||||
filelike: Union[str, bytes, Path],
|
||||
bytes_filename: str = None,
|
||||
only_bytes=True,
|
||||
only_bytes: bool = False,
|
||||
) -> Union[bytes, tuple]:
|
||||
"""
|
||||
获取文件字节信息
|
||||
Get file information.
|
||||
|
||||
Args:
|
||||
filelike: 文件数据
|
||||
bytes_filename: 通过字节数据上传需要指定文件名称,方便获取mime_type
|
||||
only_bytes: 是否只需要字节数据
|
||||
filelike: File path or file byte data.
|
||||
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
|
||||
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
ValueError: If bytes_filename is not provided when filelike is bytes or if filelike is not a valid type.
|
||||
|
||||
Notes:
|
||||
由于 parse_document 支持多种文件解析,需要上传文件时指定文件的mime_type
|
||||
Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types,
|
||||
the MIME type of the file must be specified when uploading.
|
||||
|
||||
Returns:
|
||||
[bytes, tuple]
|
||||
Returns: [bytes, tuple]
|
||||
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
|
||||
"""
|
||||
if isinstance(filelike, (str, Path)):
|
||||
filename = os.path.basename(str(filelike))
|
||||
async with aiofiles.open(filelike, "rb") as file:
|
||||
file_bytes = await file.read()
|
||||
file_bytes = await aread_bin(filelike)
|
||||
|
||||
if only_bytes:
|
||||
return file_bytes
|
||||
|
|
@ -150,60 +236,3 @@ class OmniParseClient:
|
|||
return bytes_filename, filelike, mime_type
|
||||
else:
|
||||
raise ValueError("filelike must be a string (file path) or bytes.")
|
||||
|
||||
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
|
||||
"""
|
||||
解析文档类型数据(支持 ".pdf", ".ppt", ".pptx", ".doc", ".docs")
|
||||
Args:
|
||||
filelike: 文件路径 or 文件字节数据
|
||||
bytes_filename: 字节数据名称,方便获取mime_type 用于httpx请求
|
||||
|
||||
Raises
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
OmniParsedResult
|
||||
"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
|
||||
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
|
||||
"""
|
||||
解析PDF文档
|
||||
Args:
|
||||
filelike: 文件路径 or 文件字节数据
|
||||
|
||||
Raises
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
OmniParsedResult
|
||||
"""
|
||||
self.verify_file_ext(filelike, {".pdf"})
|
||||
file_info = await self.get_file_info(filelike)
|
||||
endpoint = f"{self.parse_document_endpoint}/pdf"
|
||||
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""解析视频"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
|
||||
|
||||
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""解析音频"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
|
||||
|
||||
async def parse_website(self, url: str) -> dict:
|
||||
"""
|
||||
解析网站
|
||||
fixme:官方api还存在问题
|
||||
"""
|
||||
return await self._request_parse(f"{self.parse_website_endpoint}/parse", params={"url": url})
|
||||
|
|
@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM
|
|||
from llama_index.core.schema import Document, NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.parser import OmniParse
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
|
||||
|
|
@ -39,7 +40,7 @@ class TestSimpleEngine:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_get_file_extractor(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine.get_file_extractor")
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor")
|
||||
|
||||
def test_from_docs(
|
||||
self,
|
||||
|
|
@ -307,3 +308,17 @@ class TestSimpleEngine:
|
|||
# Assert
|
||||
assert "obj" in node.node.metadata
|
||||
assert node.node.metadata["obj"] == expected_obj
|
||||
|
||||
def test_get_file_extractor(self, mocker):
|
||||
# mock no omniparse config
|
||||
mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)
|
||||
mock_omniparse_config.base_url = ""
|
||||
|
||||
file_extractor = SimpleEngine._get_file_extractor()
|
||||
assert file_extractor == {}
|
||||
|
||||
# mock have omniparse config
|
||||
mock_omniparse_config.base_url = "http://localhost:8000"
|
||||
file_extractor = SimpleEngine._get_file_extractor(file_type=".pdf")
|
||||
assert ".pdf" in file_extractor
|
||||
assert isinstance(file_extractor[".pdf"], OmniParse)
|
||||
|
|
|
|||
|
|
@ -1,32 +1,118 @@
|
|||
import pytest
|
||||
from llama_index.core import Document
|
||||
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.rag.parser.omniparse import OmniParseClient
|
||||
from metagpt.rag.schema import OmniParsedResult
|
||||
from metagpt.rag.parser import OmniParse
|
||||
from metagpt.rag.schema import (
|
||||
OmniParsedResult,
|
||||
OmniParseOptions,
|
||||
OmniParseType,
|
||||
ParseResultType,
|
||||
)
|
||||
from metagpt.utils.omniparse_client import OmniParseClient
|
||||
|
||||
# test data
|
||||
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
|
||||
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
|
||||
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
|
||||
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
|
||||
|
||||
|
||||
class TestOmniParseClient:
|
||||
parse_client = OmniParseClient()
|
||||
|
||||
# test data
|
||||
TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx"
|
||||
TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf"
|
||||
TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4"
|
||||
TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3"
|
||||
|
||||
@pytest.fixture
|
||||
def request_parse(self, mocker):
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_pdf(self, request_parse):
|
||||
async def test_parse_pdf(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest content"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
parse_ret = await self.parse_client.parse_pdf(self.TEST_PDF)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
|
||||
assert parse_ret == mock_parsed_ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_document(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_document"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
|
||||
with open(TEST_DOCX, "rb") as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# bytes data must provide bytes_filename
|
||||
await self.parse_client.parse_document(file_bytes)
|
||||
|
||||
parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
|
||||
assert parse_ret == mock_parsed_ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_video(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_video"
|
||||
mock_request_parse.return_value = {
|
||||
"text": mock_content,
|
||||
"metadata": {},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
# Wrong file extension test
|
||||
await self.parse_client.parse_video(TEST_DOCX)
|
||||
|
||||
parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
|
||||
assert "text" in parse_ret and "metadata" in parse_ret
|
||||
assert parse_ret["text"] == mock_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_audio(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_audio"
|
||||
mock_request_parse.return_value = {
|
||||
"text": mock_content,
|
||||
"metadata": {},
|
||||
}
|
||||
parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
|
||||
assert "text" in parse_ret and "metadata" in parse_ret
|
||||
assert parse_ret["text"] == mock_content
|
||||
|
||||
|
||||
class TestOmniParse:
|
||||
def test_load_data(self):
|
||||
pass
|
||||
@pytest.fixture
|
||||
def mock_omniparse(self):
|
||||
parser = OmniParse(
|
||||
parse_options=OmniParseOptions(
|
||||
parse_type=OmniParseType.PDF,
|
||||
result_type=ParseResultType.MD,
|
||||
max_timeout=120,
|
||||
num_workers=3,
|
||||
)
|
||||
)
|
||||
return parser
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_data(self, mock_omniparse, mock_request_parse):
|
||||
# mock
|
||||
mock_content = "#test title\ntest content"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
|
||||
# single file
|
||||
documents = mock_omniparse.load_data(file_path=TEST_PDF)
|
||||
doc = documents[0]
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
|
||||
|
||||
# multi files
|
||||
file_paths = [TEST_DOCX, TEST_PDF]
|
||||
mock_omniparse.parse_type = OmniParseType.DOCUMENT
|
||||
documents = await mock_omniparse.aload_data(file_path=file_paths)
|
||||
doc = documents[0]
|
||||
|
||||
# assert
|
||||
assert isinstance(doc, Document)
|
||||
assert len(documents) == len(file_paths)
|
||||
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue