This commit is contained in:
femto 2024-10-11 16:48:08 +08:00
commit a7efa27ce0
55 changed files with 1009 additions and 72 deletions

View file

@ -6,7 +6,7 @@
@File : test_action_node.py
"""
from pathlib import Path
from typing import List, Tuple
from typing import List, Optional, Tuple
import pytest
from pydantic import BaseModel, Field, ValidationError
@ -302,6 +302,19 @@ def test_action_node_from_pydantic_and_print_everything():
assert "tasks" in code, "tasks should be in code"
def test_optional():
mapping = {
"Logic Analysis": (Optional[List[Tuple[str, str]]], Field(default=None)),
"Task list": (Optional[List[str]], None),
"Plan": (Optional[str], ""),
"Anything UNCLEAR": (Optional[str], None),
}
m = {"Anything UNCLEAR": "a"}
t = ActionNode.create_model_class("test_class_1", mapping)
t1 = t(**m)
assert t1
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()
pytest.main([__file__, "-s"])

View file

@ -64,7 +64,7 @@ def is_subset(subset, superset) -> bool:
superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}}
is_subset(subset, superset)
```
>>>False
"""
for key, value in subset.items():
if key not in superset:

View file

@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM
from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers import SimpleHybridRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
@ -37,6 +38,10 @@ class TestSimpleEngine:
def mock_get_response_synthesizer(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
@pytest.fixture
def mock_get_file_extractor(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor")
def test_from_docs(
self,
mocker,
@ -44,6 +49,7 @@ class TestSimpleEngine:
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
mock_get_file_extractor,
):
# Mock
mock_simple_directory_reader.return_value.load_data.return_value = [
@ -53,6 +59,8 @@ class TestSimpleEngine:
mock_get_retriever.return_value = mocker.MagicMock()
mock_get_rankers.return_value = [mocker.MagicMock()]
mock_get_response_synthesizer.return_value = mocker.MagicMock()
file_extractor = mocker.MagicMock()
mock_get_file_extractor.return_value = file_extractor
# Setup
input_dir = "test_dir"
@ -75,7 +83,9 @@ class TestSimpleEngine:
)
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_simple_directory_reader.assert_called_once_with(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
@ -298,3 +308,17 @@ class TestSimpleEngine:
# Assert
assert "obj" in node.node.metadata
assert node.node.metadata["obj"] == expected_obj
def test_get_file_extractor(self, mocker):
# mock no omniparse config
mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)
mock_omniparse_config.base_url = ""
file_extractor = SimpleEngine._get_file_extractor()
assert file_extractor == {}
# mock have omniparse config
mock_omniparse_config.base_url = "http://localhost:8000"
file_extractor = SimpleEngine._get_file_extractor()
assert ".pdf" in file_extractor
assert isinstance(file_extractor[".pdf"], OmniParse)

View file

@ -0,0 +1,118 @@
import pytest
from llama_index.core import Document
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.parsers import OmniParse
from metagpt.rag.schema import (
OmniParsedResult,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.omniparse_client import OmniParseClient
# test data
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
class TestOmniParseClient:
parse_client = OmniParseClient()
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_parse_pdf(self, mock_request_parse):
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_document(self, mock_request_parse):
mock_content = "#test title\ntest_parse_document"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
with open(TEST_DOCX, "rb") as f:
file_bytes = f.read()
with pytest.raises(ValueError):
# bytes data must provide bytes_filename
await self.parse_client.parse_document(file_bytes)
parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_video(self, mock_request_parse):
mock_content = "#test title\ntest_parse_video"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
with pytest.raises(ValueError):
# Wrong file extension test
await self.parse_client.parse_video(TEST_DOCX)
parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
@pytest.mark.asyncio
async def test_parse_audio(self, mock_request_parse):
mock_content = "#test title\ntest_parse_audio"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
class TestOmniParse:
@pytest.fixture
def mock_omniparse(self):
parser = OmniParse(
parse_options=OmniParseOptions(
parse_type=OmniParseType.PDF,
result_type=ParseResultType.MD,
max_timeout=120,
num_workers=3,
)
)
return parser
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_load_data(self, mock_omniparse, mock_request_parse):
# mock
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
# single file
documents = mock_omniparse.load_data(file_path=TEST_PDF)
doc = documents[0]
assert isinstance(doc, Document)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
# multi files
file_paths = [TEST_DOCX, TEST_PDF]
mock_omniparse.parse_type = OmniParseType.DOCUMENT
documents = await mock_omniparse.aload_data(file_path=file_paths)
doc = documents[0]
# assert
assert isinstance(doc, Document)
assert len(documents) == len(file_paths)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown

View file

@ -5,6 +5,7 @@ import pytest
from metagpt.provider.human_provider import HumanProvider
from metagpt.roles.role import Role
from metagpt.schema import Message, UserMessage
def test_role_desc():
@ -18,5 +19,15 @@ def test_role_human(context):
assert isinstance(role.llm, HumanProvider)
@pytest.mark.asyncio
async def test_recovered():
role = Role(profile="Tester", desc="Tester", recovered=True)
role.put_message(UserMessage(content="2"))
role.latest_observed_msg = Message(content="1")
await role._observe()
await role._observe()
assert role.rc.msg_buffer.empty()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
@ -55,6 +55,7 @@ def test_environment_serdeser(context):
assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise
assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch
def test_environment_serdeser_v2(context):
@ -69,6 +70,7 @@ def test_environment_serdeser_v2(context):
assert isinstance(role, ProjectManager)
assert isinstance(role.actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
assert list(new_env.roles.values())[0].rc.watch == pm.rc.watch
def test_environment_serdeser_save(context):
@ -85,3 +87,8 @@ def test_environment_serdeser_save(context):
new_env: Environment = Environment(**env_dict, context=context)
assert len(new_env.roles) == 1
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -28,9 +28,9 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
def test_roles(context):
role_a = RoleA()
assert len(role_a.rc.watch) == 1
assert len(role_a.rc.watch) == 2
role_b = RoleB()
assert len(role_a.rc.watch) == 1
assert len(role_a.rc.watch) == 2
assert len(role_b.rc.watch) == 1
role_d = RoleD(actions=[ActionOK()])

View file

@ -8,9 +8,9 @@ from typing import Optional
from pydantic import BaseModel, Field
from metagpt.actions import Action, ActionOutput
from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.fix_bug import FixBug
from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
@ -68,7 +68,7 @@ class RoleA(Role):
def __init__(self, **kwargs):
super(RoleA, self).__init__(**kwargs)
self.set_actions([ActionPass])
self._watch([UserRequirement])
self._watch([FixBug, UserRequirement])
class RoleB(Role):
@ -93,7 +93,7 @@ class RoleC(Role):
def __init__(self, **kwargs):
super(RoleC, self).__init__(**kwargs)
self.set_actions([ActionOK, ActionRaise])
self._watch([UserRequirement])
self._watch([FixBug, UserRequirement])
self.rc.react_mode = RoleReactMode.BY_ORDER
self.rc.memory.ignore_id = True

View file

@ -29,3 +29,7 @@ def div(a: int, b: int = 0):
assert new_action.name == "WriteCodeReview"
await new_action.run()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -14,8 +14,8 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config
def test_config_1():
cfg = Config.default()
llm = cfg.get_openai_llm()
assert llm is not None
assert llm.api_type == LLMType.OPENAI
if cfg.llm.api_type == LLMType.OPENAI:
assert llm is not None
def test_config_from_dict():

View file

@ -53,8 +53,8 @@ def test_context_1():
def test_context_2():
ctx = Context()
llm = ctx.config.get_openai_llm()
assert llm is not None
assert llm.api_type == LLMType.OPENAI
if ctx.config.llm.api_type == LLMType.OPENAI:
assert llm is not None
kwargs = ctx.kwargs
assert kwargs is not None

View file

@ -114,7 +114,6 @@ class MockLLM(OriginalLLM):
raise ValueError(
"In current test setting, api call is not allowed, you should properly mock your tests, "
"or add expected api response in tests/data/rsp_cache.json. "
f"The prompt you want for api call: {msg_key}"
)
# Call the original unmocked method
rsp = await ask_func(*args, **kwargs)