代码优化

This commit is contained in:
liuminhui 2024-07-19 13:22:10 +08:00
parent 22b9990ccf
commit 5287e024c5
8 changed files with 106 additions and 34 deletions

View file

@ -1,5 +1,7 @@
import mimetypes
import os
from pathlib import Path
import aiofiles
import httpx
from typing import Union
@ -31,37 +33,72 @@ class OmniParseClient:
self.parse_website_endpoint = "/parse_website"
self.parse_document_endpoint = "/parse_document"
async def __request_parse(
    self,
    endpoint: str,
    method: str = "POST",
    files: dict = None,
    params: dict = None,
    data: dict = None,
    json: dict = None,
    headers: dict = None,
    **kwargs,
) -> dict:
    """Send a request to the parsing API and return the decoded JSON body.

    Args:
        endpoint (str): API endpoint path, appended to ``self.base_url``.
        method (str): HTTP method name; upper-cased before use. Defaults to "POST".
        files (dict, optional): Files to upload with the request.
        params (dict, optional): Query-string parameters.
        data (dict, optional): Request body data.
        json (dict, optional): JSON request payload.
        headers (dict, optional): Extra request headers. The dict is copied,
            so the caller's object is never modified.
        **kwargs: Other keyword arguments forwarded to
            ``httpx.AsyncClient.request()``.

    Returns:
        dict: The JSON data of the response.

    Raises:
        httpx.HTTPStatusError: Propagated from ``response.raise_for_status()``
            when the server answers with an error status.
    """
    url = f"{self.base_url}{endpoint}"

    # Work on a copy so the Authorization header does not leak into the
    # dict the caller passed in.
    request_headers = dict(headers or {})
    if self.api_key:
        # The API key takes precedence over a caller-supplied Authorization.
        request_headers["Authorization"] = f"Bearer {self.api_key}"

    async with httpx.AsyncClient() as client:
        response = await client.request(
            url=url,
            method=method.upper(),
            files=files,
            params=params,
            json=json,
            data=data,
            headers=request_headers,
            timeout=self.max_timeout,
            **kwargs,
        )
        response.raise_for_status()
        return response.json()
@staticmethod
def verify_file_ext(filelike: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
    """Validate a file's extension against a set of allowed extensions.

    Args:
        filelike: File path (``str`` or ``Path``) or raw file bytes.
        allowed_file_extensions: Allowed extensions. They are compared with
            the lower-cased extension returned by ``os.path.splitext``,
            which includes the leading dot (e.g. ``".pdf"``).
        bytes_filename: File name used for the check when ``filelike`` is
            bytes.

    Raises:
        ValueError: If the extension is not in ``allowed_file_extensions``.

    Returns:
        None.
    """
    if isinstance(filelike, (str, Path)):
        verify_file_path = str(filelike)
    elif isinstance(filelike, bytes) and bytes_filename:
        verify_file_path = bytes_filename
    else:
        verify_file_path = None

    if not verify_file_path:
        # Nothing to derive an extension from (bytes without a file name,
        # or an empty path): skip validation.
        return

    file_ext = os.path.splitext(verify_file_path)[1].lower()
    if file_ext not in allowed_file_extensions:
        raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")
@staticmethod
async def get_file_info(
self,
filelike: Union[str, bytes],
filelike: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes=True,
) -> Union[bytes, tuple]:
@ -72,14 +109,17 @@ class OmniParseClient:
bytes_filename: 通过字节数据上传需要指定文件名称方便获取mime_type
only_bytes: 是否只需要字节数据
Raises:
ValueError
Notes:
由于 parse_document 支持多种文件解析需要上传文件时指定文件的mime_type
Returns:
[bytes, tuple]
"""
if isinstance(filelike, str):
filename = os.path.basename(filelike)
if isinstance(filelike, (str, Path)):
filename = os.path.basename(str(filelike))
async with aiofiles.open(filelike, 'rb') as file:
file_bytes = await file.read()
@ -99,7 +139,7 @@ class OmniParseClient:
else:
raise ValueError("filelike must be a string (file path) or bytes.")
async def parse_document(self, filelike: Union[str, bytes], bytes_filename: str = None) -> OmniParsedResult:
async def parse_document(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
解析文档类型数据支持 ".pdf", ".ppt", ".pptx", ".doc", ".docs"
Args:
@ -112,13 +152,13 @@ class OmniParseClient:
Returns:
OmniParsedResult
"""
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS)
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
resp = await self.__request_parse(self.parse_document_endpoint, files={'file': file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, filelike: Union[str, bytes]) -> OmniParsedResult:
async def parse_pdf(self, filelike: Union[str, bytes, Path]) -> OmniParsedResult:
"""
解析PDF文档
Args:
@ -137,18 +177,21 @@ class OmniParseClient:
data = OmniParsedResult(**resp)
return data
async def parse_video(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
    """Parse a video file through the media endpoint.

    Args:
        filelike: File path (``str`` or ``Path``) or raw file bytes.
        bytes_filename: File name to use when ``filelike`` is bytes; it is
            passed to both the extension check and ``get_file_info``.

    Raises:
        ValueError: If the extension is not in
            ``self.ALLOWED_VIDEO_EXTENSIONS``.

    Returns:
        dict: The JSON data of the response.
    """
    # Validate once, with bytes_filename, so byte uploads are checked too.
    self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
    file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
    return await self.__request_parse(f"{self.parse_media_endpoint}/video", files={'file': file_info})
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
    """Parse an audio file through the media endpoint.

    Args:
        filelike: File path (``str`` or ``Path``) or raw file bytes.
        bytes_filename: File name to use when ``filelike`` is bytes; it is
            passed to both the extension check and ``get_file_info``.

    Raises:
        ValueError: If the extension is not in
            ``self.ALLOWED_AUDIO_EXTENSIONS``.

    Returns:
        dict: The JSON data of the response.
    """
    # Validate once, with bytes_filename, so byte uploads are checked too.
    self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
    file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
    return await self.__request_parse(f"{self.parse_media_endpoint}/audio", files={'file': file_info})
async def parse_website(self, url: str) -> dict:
    """Parse a website.

    The target URL is sent as the ``url`` query-string parameter to the
    ``/parse`` route under ``self.parse_website_endpoint``.

    FIXME: the official API still has known problems.

    Args:
        url: Address of the website to parse.

    Returns:
        dict: The JSON data of the response.
    """
    return await self.__request_parse(f"{self.parse_website_endpoint}/parse", params={'url': url})

View file

@ -1,5 +1,6 @@
import asyncio
from fileinput import FileInput
from pathlib import Path
from typing import List, Union, Optional
from llama_index.core import Document
@ -10,6 +11,7 @@ from llama_parse import ResultType
from metagpt.rag.parser.omniparse.client import OmniParseClient
from metagpt.rag.schema import OmniParseOptions, OmniParseType
from metagpt.logs import logger
from metagpt.utils.async_helper import NestAsyncio
class OmniParse(BaseReader):
@ -46,7 +48,7 @@ class OmniParse(BaseReader):
async def _aload_data(
self,
file_path: Union[str, bytes],
file_path: Union[str, bytes, Path],
extra_info: Optional[dict] = None,
) -> List[Document]:
try:
@ -78,7 +80,7 @@ class OmniParse(BaseReader):
extra_info: Optional[dict] = None,
) -> List[Document]:
docs = []
if isinstance(file_path, (str, bytes)):
if isinstance(file_path, (str, bytes, Path)):
# 处理单个
docs = await self._aload_data(file_path, extra_info)
elif isinstance(file_path, list):
@ -94,4 +96,5 @@ class OmniParse(BaseReader):
extra_info: Optional[dict] = None,
) -> List[Document]:
"""Load data from the input path."""
NestAsyncio.apply_once() # 兼容异步嵌套调用
return asyncio.run(self.aload_data(file_path, extra_info))

View file

@ -229,7 +229,7 @@ class OmniParseOptions(BaseModel):
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型默认文档类型")
max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时")
num_workers: int = Field(
default=4,
default=5,
gt=0,
lt=10,
description="多文件列表时并发请求数量",