单测相关

This commit is contained in:
liuminhui 2024-07-19 18:35:02 +08:00
parent 5287e024c5
commit 79334de5a4
4 changed files with 82 additions and 27 deletions

View file

@ -1,10 +1,10 @@
import mimetypes
import os
from pathlib import Path
from typing import Union
import aiofiles
import httpx
from typing import Union
from metagpt.rag.schema import OmniParsedResult
@ -14,6 +14,7 @@ class OmniParseClient:
OmniParse Server Client
OmniParse API Docs: https://docs.cognitivelab.in/api
"""
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
@ -33,11 +34,16 @@ class OmniParseClient:
self.parse_website_endpoint = "/parse_website"
self.parse_document_endpoint = "/parse_document"
async def __request_parse(
self, endpoint: str, method: str = "POST",
files: dict = None, params: dict = None,
data: dict = None, json: dict = None,
headers: dict = None, **kwargs,
async def _request_parse(
self,
endpoint: str,
method: str = "POST",
files: dict = None,
params: dict = None,
data: dict = None,
json: dict = None,
headers: dict = None,
**kwargs,
) -> dict:
"""
请求api解析文档
@ -61,9 +67,15 @@ class OmniParseClient:
headers.update(**_headers)
async with httpx.AsyncClient() as client:
response = await client.request(
url=url, method=method,
files=files, params=params, json=json, data=data,
headers=headers, timeout=self.max_timeout, **kwargs,
url=url,
method=method,
files=files,
params=params,
json=json,
data=data,
headers=headers,
timeout=self.max_timeout,
**kwargs,
)
response.raise_for_status()
return response.json()
@ -98,9 +110,9 @@ class OmniParseClient:
@staticmethod
async def get_file_info(
filelike: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes=True,
filelike: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes=True,
) -> Union[bytes, tuple]:
"""
获取文件字节信息
@ -120,7 +132,7 @@ class OmniParseClient:
"""
if isinstance(filelike, (str, Path)):
filename = os.path.basename(str(filelike))
async with aiofiles.open(filelike, 'rb') as file:
async with aiofiles.open(filelike, "rb") as file:
file_bytes = await file.read()
if only_bytes:
@ -154,7 +166,7 @@ class OmniParseClient:
"""
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
resp = await self.__request_parse(self.parse_document_endpoint, files={'file': file_info})
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
@ -173,7 +185,7 @@ class OmniParseClient:
self.verify_file_ext(filelike, {".pdf"})
file_info = await self.get_file_info(filelike)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self.__request_parse(endpoint=endpoint, files={'file': file_info})
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
@ -181,17 +193,17 @@ class OmniParseClient:
"""解析视频"""
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
return await self.__request_parse(f"{self.parse_media_endpoint}/video", files={'file': file_info})
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
    """Parse an audio file through the OmniParse media endpoint.

    Args:
        filelike: Path to the audio file (``str`` or ``Path``) or its raw bytes.
        bytes_filename: File name passed on to the extension check and to
            ``get_file_info`` (presumably only needed when ``filelike`` is raw
            bytes; confirm against ``verify_file_ext``).

    Returns:
        Whatever ``_request_parse`` returns for ``<parse_media_endpoint>/audio``.
    """
    self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
    # only_bytes=False asks get_file_info for the full file description rather than bare bytes.
    file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
    # Single return through the single-underscore helper: the double-underscore name is
    # mangled per class and cannot be patched as "OmniParseClient._request_parse".
    return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
async def parse_website(self, url: str) -> dict:
    """Parse a website through the OmniParse website endpoint.

    FIXME: the official API still has problems.

    Args:
        url: Address of the page to parse; sent as the ``url`` query parameter.

    Returns:
        Whatever ``_request_parse`` returns for ``<parse_website_endpoint>/parse``.
    """
    # Single return through the single-underscore helper: the double-underscore name is
    # mangled per class and cannot be patched as "OmniParseClient._request_parse".
    return await self._request_parse(f"{self.parse_website_endpoint}/parse", params={"url": url})

View file

@ -1,9 +1,9 @@
"""RAG schemas."""
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, Literal, Optional, Union, List
from typing import Any, ClassVar, List, Literal, Optional, Union
# from chromadb.api.types import CollectionMetadata
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
@ -68,9 +68,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
# metadata: Optional[CollectionMetadata] = Field(
# default=None, description="Optional metadata to associate with the collection"
# )
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class ElasticsearchStoreConfig(BaseModel):
@ -166,9 +166,9 @@ class ChromaIndexConfig(VectorIndexConfig):
"""Config for chroma-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
# metadata: Optional[CollectionMetadata] = Field(
# default=None, description="Optional metadata to associate with the collection"
# )
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class BM25IndexConfig(BaseIndexConfig):
@ -219,12 +219,14 @@ class ObjectNode(TextNode):
class OmniParseType(str, Enum):
    """OmniParse parse types.

    String-valued enum: every member is also a ``str`` equal to its own name.
    """

    PDF = "PDF"
    DOCUMENT = "DOCUMENT"
class OmniParseOptions(BaseModel):
    """Optional configuration for OmniParse."""

    # Result type returned by OmniParse parsing; defaults to ResultType.MD.
    result_type: ResultType = Field(default=ResultType.MD, description="OmniParse解析返回的结果类型")
    # Kind of parsing to perform; defaults to the generic document type.
    parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型默认文档类型")
    # Maximum timeout for requests to the OmniParse service; defaults to 120
    # (unit is not shown here, presumably seconds; confirm against the client).
    max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时")

View file

@ -37,6 +37,10 @@ class TestSimpleEngine:
def mock_get_response_synthesizer(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
@pytest.fixture
def mock_get_file_extractor(self, mocker):
    """Patch ``SimpleEngine.get_file_extractor`` and hand the resulting mock to the test."""
    patched_extractor_getter = mocker.patch("metagpt.rag.engines.simple.SimpleEngine.get_file_extractor")
    return patched_extractor_getter
def test_from_docs(
self,
mocker,
@ -44,6 +48,7 @@ class TestSimpleEngine:
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
mock_get_file_extractor,
):
# Mock
mock_simple_directory_reader.return_value.load_data.return_value = [
@ -53,6 +58,8 @@ class TestSimpleEngine:
mock_get_retriever.return_value = mocker.MagicMock()
mock_get_rankers.return_value = [mocker.MagicMock()]
mock_get_response_synthesizer.return_value = mocker.MagicMock()
file_extractor = mocker.MagicMock()
mock_get_file_extractor.return_value = file_extractor
# Setup
input_dir = "test_dir"
@ -75,7 +82,9 @@ class TestSimpleEngine:
)
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_simple_directory_reader.assert_called_once_with(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)

View file

@ -0,0 +1,32 @@
import pytest
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.parser.omniparse import OmniParseClient
from metagpt.rag.schema import OmniParsedResult
class TestOmniParseClient:
    """Tests for ``OmniParseClient`` with its request helper replaced by a mock."""

    # A single client instance, created once and shared by every test in the class.
    parse_client = OmniParseClient()

    # Sample inputs located under the example data directory.
    TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx"
    TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf"
    TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4"
    TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3"

    @pytest.fixture
    def request_parse(self, mocker):
        """Swap ``OmniParseClient._request_parse`` for a mock and return that mock."""
        patched_request = mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
        return patched_request

    @pytest.mark.asyncio
    async def test_parse_pdf(self, request_parse):
        """``parse_pdf`` must rebuild an ``OmniParsedResult`` from the mocked response."""
        # Arrange: the mocked request helper answers with the dumped expected result.
        markdown_text = "#test title\ntest content"
        expected = OmniParsedResult(text=markdown_text, markdown=markdown_text)
        request_parse.return_value = expected.model_dump()

        # Act: only the request helper is mocked, so the sample PDF path is still handed
        # to the client's own file handling.
        actual = await self.parse_client.parse_pdf(self.TEST_PDF)

        # Assert
        assert actual == expected
class TestOmniParse:
    """Placeholder test class; no behaviour is exercised yet."""

    def test_load_data(self):
        """Intentionally empty: performs no checks and returns ``None``."""
        # TODO: replace with real assertions once a load_data test is written.
        return None