mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-27 09:46:24 +02:00
代码优化
This commit is contained in:
parent
22b9990ccf
commit
5287e024c5
8 changed files with 106 additions and 34 deletions
|
|
@ -1,5 +1,7 @@
|
|||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
from typing import Union
|
||||
|
|
@ -31,37 +33,72 @@ class OmniParseClient:
|
|||
self.parse_website_endpoint = "/parse_website"
|
||||
self.parse_document_endpoint = "/parse_document"
|
||||
|
||||
async def __request_parse(
    self,
    endpoint: str,
    method: str = "POST",
    files: dict = None,
    params: dict = None,
    data: dict = None,
    json: dict = None,
    headers: dict = None,
    **kwargs,
) -> dict:
    """Send one request to the parsing API and return the decoded JSON body.

    Args:
        endpoint (str): API endpoint, appended verbatim to ``self.base_url``.
        method (str): HTTP method name; upper-cased before use. Defaults to "POST".
        files (dict, optional): Multipart file payload.
        params (dict, optional): Query-string parameters.
        data (dict, optional): Request body data.
        json (dict, optional): JSON request payload.
        headers (dict, optional): Extra request headers. The dict is copied,
            never modified in place.
        **kwargs: Any other keyword arguments accepted by
            ``httpx.AsyncClient.request()``.

    Returns:
        dict: The JSON-decoded response body.

    Raises:
        httpx.HTTPStatusError: Raised by ``raise_for_status()`` when the
            server answers with an error status code.
    """
    url = f"{self.base_url}{endpoint}"

    # Build a fresh dict so the Authorization header is never written into
    # a dict object owned by the caller.
    request_headers = dict(headers or {})
    if self.api_key:
        request_headers["Authorization"] = f"Bearer {self.api_key}"

    async with httpx.AsyncClient() as client:
        response = await client.request(
            url=url,
            method=method.upper(),
            files=files,
            params=params,
            json=json,
            data=data,
            headers=request_headers,
            timeout=self.max_timeout,
            **kwargs,
        )
        response.raise_for_status()
        return response.json()
|
||||
|
||||
@staticmethod
def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
    """Check that a file's extension is in the allowed set.

    Args:
        filelike: A file path (``str`` or any path-like object) or the raw
            file bytes.
        allowed_file_extensions: Allowed extensions. They are compared with
            the lower-cased extension including its leading dot (".pdf").
        bytes_filename: File name to validate when ``filelike`` is bytes.

    Raises:
        ValueError: If the extension is not in ``allowed_file_extensions``.

    Returns:
        None. Validation is skipped when no file name can be determined,
        for example bytes passed without ``bytes_filename``.
    """
    if isinstance(filelike, (str, os.PathLike)):
        # os.PathLike covers pathlib.Path as well as other path-like objects.
        name_to_check = os.fspath(filelike)
    elif isinstance(filelike, bytes) and bytes_filename:
        name_to_check = bytes_filename
    else:
        name_to_check = None

    if not name_to_check:
        # Raw bytes without a file name carry no extension to validate.
        return

    file_ext = os.path.splitext(name_to_check)[1].lower()
    if file_ext not in allowed_file_extensions:
        raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")
|
||||
|
||||
@staticmethod
|
||||
async def get_file_info(
|
||||
self,
|
||||
filelike: Union[str, bytes],
|
||||
filelike: Union[str, bytes, Path],
|
||||
bytes_filename: str = None,
|
||||
only_bytes=True,
|
||||
) -> Union[bytes, tuple]:
|
||||
|
|
@ -72,14 +109,17 @@ class OmniParseClient:
|
|||
bytes_filename: 通过字节数据上传需要指定文件名称,方便获取mime_type
|
||||
only_bytes: 是否只需要字节数据
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Notes:
|
||||
由于 parse_document 支持多种文件解析,需要上传文件时指定文件的mime_type
|
||||
|
||||
Returns:
|
||||
[bytes, tuple]
|
||||
"""
|
||||
if isinstance(filelike, str):
|
||||
filename = os.path.basename(filelike)
|
||||
if isinstance(filelike, (str, Path)):
|
||||
filename = os.path.basename(str(filelike))
|
||||
async with aiofiles.open(filelike, 'rb') as file:
|
||||
file_bytes = await file.read()
|
||||
|
||||
|
|
@ -99,7 +139,7 @@ class OmniParseClient:
|
|||
else:
|
||||
raise ValueError("filelike must be a string (file path) or bytes.")
|
||||
|
||||
async def parse_document(self, filelike: Union[str, bytes], bytes_filename: str = None) -> OmniParsedResult:
|
||||
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
|
||||
"""
|
||||
解析文档类型数据(支持 ".pdf", ".ppt", ".pptx", ".doc", ".docs")
|
||||
Args:
|
||||
|
|
@ -112,13 +152,13 @@ class OmniParseClient:
|
|||
Returns:
|
||||
OmniParsedResult
|
||||
"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS)
|
||||
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
|
||||
resp = await self.__request_parse(self.parse_document_endpoint, files={'file': file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_pdf(self, filelike: Union[str, bytes]) -> OmniParsedResult:
|
||||
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
|
||||
"""
|
||||
解析PDF文档
|
||||
Args:
|
||||
|
|
@ -137,18 +177,21 @@ class OmniParseClient:
|
|||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
    """Parse a video file through the media endpoint.

    Args:
        filelike: Path to the video file, or its raw bytes.
        bytes_filename: File name to use when ``filelike`` is bytes. It is
            forwarded to both the extension check and ``get_file_info``.

    Returns:
        dict: The JSON response of the ``<parse_media_endpoint>/video`` request.
    """
    # Forward bytes_filename so byte uploads are validated by name as well.
    self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
    file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
    return await self.__request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
|
||||
|
||||
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
    """Parse an audio file through the media endpoint.

    Args:
        filelike: Path to the audio file, or its raw bytes.
        bytes_filename: File name to use when ``filelike`` is bytes. It is
            forwarded to both the extension check and ``get_file_info``.

    Returns:
        dict: The JSON response of the ``<parse_media_endpoint>/audio`` request.
    """
    # Forward bytes_filename so byte uploads are validated by name as well.
    self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
    file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
    return await self.__request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
|
||||
|
||||
async def parse_website(self, url: str) -> dict:
    """Parse a website.

    The target ``url`` is sent as a query-string parameter to
    ``<parse_website_endpoint>/parse``.

    FIXME: the official API still has known problems.

    Args:
        url: Address of the website to parse.

    Returns:
        dict: The JSON response of the request.
    """
    return await self.__request_parse(f"{self.parse_website_endpoint}/parse", params={"url": url})
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
from fileinput import FileInput
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from llama_index.core import Document
|
||||
|
|
@ -10,6 +11,7 @@ from llama_parse import ResultType
|
|||
from metagpt.rag.parser.omniparse.client import OmniParseClient
|
||||
from metagpt.rag.schema import OmniParseOptions, OmniParseType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
|
||||
|
||||
class OmniParse(BaseReader):
|
||||
|
|
@ -46,7 +48,7 @@ class OmniParse(BaseReader):
|
|||
|
||||
async def _aload_data(
|
||||
self,
|
||||
file_path: Union[str, bytes],
|
||||
file_path: Union[str, bytes, Path],
|
||||
extra_info: Optional[dict] = None,
|
||||
) -> List[Document]:
|
||||
try:
|
||||
|
|
@ -78,7 +80,7 @@ class OmniParse(BaseReader):
|
|||
extra_info: Optional[dict] = None,
|
||||
) -> List[Document]:
|
||||
docs = []
|
||||
if isinstance(file_path, (str, bytes)):
|
||||
if isinstance(file_path, (str, bytes, Path)):
|
||||
# 处理单个
|
||||
docs = await self._aload_data(file_path, extra_info)
|
||||
elif isinstance(file_path, list):
|
||||
|
|
@ -94,4 +96,5 @@ class OmniParse(BaseReader):
|
|||
extra_info: Optional[dict] = None,
|
||||
) -> List[Document]:
|
||||
"""Load data from the input path."""
|
||||
NestAsyncio.apply_once() # 兼容异步嵌套调用
|
||||
return asyncio.run(self.aload_data(file_path, extra_info))
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ class OmniParseOptions(BaseModel):
|
|||
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型,默认文档类型")
|
||||
max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时")
|
||||
num_workers: int = Field(
|
||||
default=4,
|
||||
default=5,
|
||||
gt=0,
|
||||
lt=10,
|
||||
description="多文件列表时并发请求数量",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue