mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
单测相关
This commit is contained in:
parent
5287e024c5
commit
79334de5a4
4 changed files with 82 additions and 27 deletions
|
|
@ -1,10 +1,10 @@
|
|||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
from typing import Union
|
||||
|
||||
from metagpt.rag.schema import OmniParsedResult
|
||||
|
||||
|
|
@ -14,6 +14,7 @@ class OmniParseClient:
|
|||
OmniParse Server Client
|
||||
OmniParse API Docs: https://docs.cognitivelab.in/api
|
||||
"""
|
||||
|
||||
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
|
||||
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
|
||||
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
|
||||
|
|
@ -33,11 +34,16 @@ class OmniParseClient:
|
|||
self.parse_website_endpoint = "/parse_website"
|
||||
self.parse_document_endpoint = "/parse_document"
|
||||
|
||||
async def __request_parse(
|
||||
self, endpoint: str, method: str = "POST",
|
||||
files: dict = None, params: dict = None,
|
||||
data: dict = None, json: dict = None,
|
||||
headers: dict = None, **kwargs,
|
||||
async def _request_parse(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str = "POST",
|
||||
files: dict = None,
|
||||
params: dict = None,
|
||||
data: dict = None,
|
||||
json: dict = None,
|
||||
headers: dict = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
请求api解析文档
|
||||
|
|
@ -61,9 +67,15 @@ class OmniParseClient:
|
|||
headers.update(**_headers)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.request(
|
||||
url=url, method=method,
|
||||
files=files, params=params, json=json, data=data,
|
||||
headers=headers, timeout=self.max_timeout, **kwargs,
|
||||
url=url,
|
||||
method=method,
|
||||
files=files,
|
||||
params=params,
|
||||
json=json,
|
||||
data=data,
|
||||
headers=headers,
|
||||
timeout=self.max_timeout,
|
||||
**kwargs,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
|
@ -98,9 +110,9 @@ class OmniParseClient:
|
|||
|
||||
@staticmethod
|
||||
async def get_file_info(
|
||||
filelike: Union[str, bytes, Path],
|
||||
bytes_filename: str = None,
|
||||
only_bytes=True,
|
||||
filelike: Union[str, bytes, Path],
|
||||
bytes_filename: str = None,
|
||||
only_bytes=True,
|
||||
) -> Union[bytes, tuple]:
|
||||
"""
|
||||
获取文件字节信息
|
||||
|
|
@ -120,7 +132,7 @@ class OmniParseClient:
|
|||
"""
|
||||
if isinstance(filelike, (str, Path)):
|
||||
filename = os.path.basename(str(filelike))
|
||||
async with aiofiles.open(filelike, 'rb') as file:
|
||||
async with aiofiles.open(filelike, "rb") as file:
|
||||
file_bytes = await file.read()
|
||||
|
||||
if only_bytes:
|
||||
|
|
@ -154,7 +166,7 @@ class OmniParseClient:
|
|||
"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
|
||||
resp = await self.__request_parse(self.parse_document_endpoint, files={'file': file_info})
|
||||
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
|
|
@ -173,7 +185,7 @@ class OmniParseClient:
|
|||
self.verify_file_ext(filelike, {".pdf"})
|
||||
file_info = await self.get_file_info(filelike)
|
||||
endpoint = f"{self.parse_document_endpoint}/pdf"
|
||||
resp = await self.__request_parse(endpoint=endpoint, files={'file': file_info})
|
||||
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
|
||||
data = OmniParsedResult(**resp)
|
||||
return data
|
||||
|
||||
|
|
@ -181,17 +193,17 @@ class OmniParseClient:
|
|||
"""解析视频"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
|
||||
return await self.__request_parse(f"{self.parse_media_endpoint}/video", files={'file': file_info})
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
|
||||
|
||||
async def parse_audio(self, filelike: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
|
||||
"""解析音频"""
|
||||
self.verify_file_ext(filelike, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
|
||||
file_info = await self.get_file_info(filelike, bytes_filename, only_bytes=False)
|
||||
return await self.__request_parse(f"{self.parse_media_endpoint}/audio", files={'file': file_info})
|
||||
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
|
||||
|
||||
async def parse_website(self, url: str) -> dict:
|
||||
"""
|
||||
解析网站
|
||||
fixme:官方api还存在问题
|
||||
"""
|
||||
return await self.__request_parse(f"{self.parse_website_endpoint}/parse", params={'url': url})
|
||||
return await self._request_parse(f"{self.parse_website_endpoint}/parse", params={"url": url})
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"""RAG schemas."""
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Literal, Optional, Union, List
|
||||
from typing import Any, ClassVar, List, Literal, Optional, Union
|
||||
|
||||
# from chromadb.api.types import CollectionMetadata
|
||||
from chromadb.api.types import CollectionMetadata
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
|
|
@ -68,9 +68,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
|
|||
|
||||
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
# metadata: Optional[CollectionMetadata] = Field(
|
||||
# default=None, description="Optional metadata to associate with the collection"
|
||||
# )
|
||||
metadata: Optional[CollectionMetadata] = Field(
|
||||
default=None, description="Optional metadata to associate with the collection"
|
||||
)
|
||||
|
||||
|
||||
class ElasticsearchStoreConfig(BaseModel):
|
||||
|
|
@ -166,9 +166,9 @@ class ChromaIndexConfig(VectorIndexConfig):
|
|||
"""Config for chroma-based index."""
|
||||
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
# metadata: Optional[CollectionMetadata] = Field(
|
||||
# default=None, description="Optional metadata to associate with the collection"
|
||||
# )
|
||||
metadata: Optional[CollectionMetadata] = Field(
|
||||
default=None, description="Optional metadata to associate with the collection"
|
||||
)
|
||||
|
||||
|
||||
class BM25IndexConfig(BaseIndexConfig):
|
||||
|
|
@ -219,12 +219,14 @@ class ObjectNode(TextNode):
|
|||
|
||||
class OmniParseType(str, Enum):
|
||||
"""OmniParse解析类型"""
|
||||
|
||||
PDF = "PDF"
|
||||
DOCUMENT = "DOCUMENT"
|
||||
|
||||
|
||||
class OmniParseOptions(BaseModel):
|
||||
"""OmniParse可选配置"""
|
||||
|
||||
result_type: ResultType = Field(default=ResultType.MD, description="OmniParse解析返回的结果类型")
|
||||
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse解析类型,默认文档类型")
|
||||
max_timeout: Optional[int] = Field(default=120, description="OmniParse服务请求最大超时")
|
||||
|
|
|
|||
|
|
@ -37,6 +37,10 @@ class TestSimpleEngine:
|
|||
def mock_get_response_synthesizer(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_file_extractor(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine.get_file_extractor")
|
||||
|
||||
def test_from_docs(
|
||||
self,
|
||||
mocker,
|
||||
|
|
@ -44,6 +48,7 @@ class TestSimpleEngine:
|
|||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
mock_get_file_extractor,
|
||||
):
|
||||
# Mock
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
|
|
@ -53,6 +58,8 @@ class TestSimpleEngine:
|
|||
mock_get_retriever.return_value = mocker.MagicMock()
|
||||
mock_get_rankers.return_value = [mocker.MagicMock()]
|
||||
mock_get_response_synthesizer.return_value = mocker.MagicMock()
|
||||
file_extractor = mocker.MagicMock()
|
||||
mock_get_file_extractor.return_value = file_extractor
|
||||
|
||||
# Setup
|
||||
input_dir = "test_dir"
|
||||
|
|
@ -75,7 +82,9 @@ class TestSimpleEngine:
|
|||
)
|
||||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_simple_directory_reader.assert_called_once_with(
|
||||
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
|
||||
)
|
||||
mock_get_retriever.assert_called_once()
|
||||
mock_get_rankers.assert_called_once()
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
|
|
|
|||
32
tests/metagpt/rag/parser/test_omniparse.py
Normal file
32
tests/metagpt/rag/parser/test_omniparse.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.rag.parser.omniparse import OmniParseClient
|
||||
from metagpt.rag.schema import OmniParsedResult
|
||||
|
||||
|
||||
class TestOmniParseClient:
|
||||
parse_client = OmniParseClient()
|
||||
|
||||
# test data
|
||||
TEST_DOCX = EXAMPLE_DATA_PATH / "parse/test01.docx"
|
||||
TEST_PDF = EXAMPLE_DATA_PATH / "parse/test02.pdf"
|
||||
TEST_VIDEO = EXAMPLE_DATA_PATH / "parse/test03.mp4"
|
||||
TEST_AUDIO = EXAMPLE_DATA_PATH / "parse/test04.mp3"
|
||||
|
||||
@pytest.fixture
|
||||
def request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parser.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_pdf(self, request_parse):
|
||||
mock_content = "#test title\ntest content"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
parse_ret = await self.parse_client.parse_pdf(self.TEST_PDF)
|
||||
assert parse_ret == mock_parsed_ret
|
||||
|
||||
|
||||
class TestOmniParse:
|
||||
def test_load_data(self):
|
||||
pass
|
||||
Loading…
Add table
Add a link
Reference in a new issue