Merge branch 'mgx_ops' into role_zero_draft

This commit is contained in:
garylin2099 2024-06-05 15:00:49 +08:00
commit bbddbf4ef0
61 changed files with 1743 additions and 577 deletions

View file

@ -303,5 +303,4 @@ def test_action_node_from_pydantic_and_print_everything():
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()
pytest.main([__file__, "-s"])

View file

@ -6,37 +6,104 @@
@File : test_design_api.py
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
"""
from pathlib import Path
import pytest
from metagpt.actions.design_api import WriteDesign
from metagpt.llm import LLM
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.schema import AIMessage, Message
from metagpt.utils.project_repo import ProjectRepo
from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON
@pytest.mark.asyncio
async def test_design_api(context):
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE
for prd in inputs:
await context.repo.docs.prd.save(filename="new_prd.txt", content=prd)
async def test_design(context):
# Mock new design env
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
context.kwargs.project_path = context.config.project_path
context.kwargs.inc = False
filename = "prd.txt"
repo = ProjectRepo(context.kwargs.project_path)
await repo.docs.prd.save(filename=filename, content=prd)
kvs = {
"project_path": str(context.kwargs.project_path),
"changed_prd_filenames": [str(repo.docs.prd.workdir / filename)],
}
instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput")
design_api = WriteDesign(context=context)
result = await design_api.run(Message(content=prd, instruct_content=None))
logger.info(result)
assert result
@pytest.mark.asyncio
async def test_refined_design_api(context):
await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON))
await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE)
design_api = WriteDesign(context=context, llm=LLM())
result = await design_api.run(Message(content="", instruct_content=None))
design_api = WriteDesign(context=context)
result = await design_api.run([Message(content=prd, instruct_content=instruct_content)])
logger.info(result)
assert result
assert isinstance(result, AIMessage)
assert result.instruct_content
assert repo.docs.system_design.changed_files
# Mock incremental design env
context.kwargs.inc = True
await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON))
await repo.docs.system_design.save(filename=filename, content=DESIGN_SAMPLE)
result = await design_api.run([Message(content="", instruct_content=instruct_content)])
logger.info(result)
assert result
assert isinstance(result, AIMessage)
assert result.instruct_content
assert repo.docs.system_design.changed_files
@pytest.mark.parametrize(
("user_requirement", "prd_filename", "legacy_design_filename"),
[
("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None),
("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None),
(
"write 2048 game",
str(METAGPT_ROOT / "tests/data/prd.json"),
str(METAGPT_ROOT / "tests/data/system_design.json"),
),
],
)
@pytest.mark.asyncio
async def test_design_api(context, user_requirement, prd_filename, legacy_design_filename):
action = WriteDesign()
result = await action.run(
user_requirement=user_requirement, prd_filename=prd_filename, legacy_design_filename=legacy_design_filename
)
assert isinstance(result, AIMessage)
assert result.content
assert str(DEFAULT_WORKSPACE_ROOT) in result.content
@pytest.mark.parametrize(
("user_requirement", "prd_filename", "legacy_design_filename"),
[
("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None),
("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None),
(
"write 2048 game",
str(METAGPT_ROOT / "tests/data/prd.json"),
str(METAGPT_ROOT / "tests/data/system_design.json"),
),
],
)
@pytest.mark.asyncio
async def test_design_api_dir(context, user_requirement, prd_filename, legacy_design_filename):
action = WriteDesign()
result = await action.run(
user_requirement=user_requirement,
prd_filename=prd_filename,
legacy_design_filename=legacy_design_filename,
output_pathname=str(Path(context.config.project_path) / "1.txt"),
)
assert isinstance(result, AIMessage)
assert result.content
assert str(context.config.project_path) in result.content
assert result.instruct_content
assert result.instruct_content.changed_system_design_filenames
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -5,13 +5,15 @@
@Author : alexanderwu
@File : test_project_management.py
"""
import json
import pytest
from metagpt.actions.project_management import WriteTasks
from metagpt.llm import LLM
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.schema import AIMessage, Message
from metagpt.utils.project_repo import ProjectRepo
from tests.data.incremental_dev_project.mock import (
REFINED_DESIGN_JSON,
REFINED_PRD_JSON,
@ -22,29 +24,46 @@ from tests.metagpt.actions.mock_json import DESIGN, PRD
@pytest.mark.asyncio
async def test_task(context):
await context.repo.docs.prd.save("1.txt", content=str(PRD))
await context.repo.docs.system_design.save("1.txt", content=str(DESIGN))
logger.info(context.git_repo)
# Mock write tasks env
context.kwargs.project_path = context.config.project_path
context.kwargs.inc = False
repo = ProjectRepo(context.kwargs.project_path)
filename = "1.txt"
await repo.docs.prd.save(filename=filename, content=str(PRD))
await repo.docs.system_design.save(filename=filename, content=str(DESIGN))
kvs = {
"project_path": context.kwargs.project_path,
"changed_system_design_filenames": [str(repo.docs.system_design.workdir / filename)],
}
instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput")
action = WriteTasks(context=context)
result = await action.run(Message(content="", instruct_content=None))
result = await action.run([Message(content="", instruct_content=instruct_content)])
logger.info(result)
assert result
assert result.instruct_content.changed_task_filenames
# Mock incremental env
context.kwargs.inc = True
await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON))
await repo.docs.system_design.save(filename=filename, content=str(REFINED_DESIGN_JSON))
await repo.docs.task.save(filename=filename, content=TASK_SAMPLE)
result = await action.run([Message(content="", instruct_content=instruct_content)])
logger.info(result)
assert result
assert result.instruct_content.changed_task_filenames
@pytest.mark.asyncio
async def test_refined_task(context):
await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON))
await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON))
await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE)
logger.info(context.git_repo)
action = WriteTasks(context=context, llm=LLM())
result = await action.run(Message(content="", instruct_content=None))
logger.info(result)
async def test_task_api(context):
action = WriteTasks()
result = await action.run(design_filename=str(METAGPT_ROOT / "tests/data/system_design.json"))
assert result
assert result.content
m = json.loads(result.content)
assert m
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -26,12 +26,7 @@ from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPL
def setup_inc_workdir(context, inc: bool = False):
"""setup incremental workdir for testing"""
context.src_workspace = context.git_repo.workdir / "src"
if inc:
context.config.inc = inc
context.repo.old_workspace = context.repo.git_repo.workdir / "old"
context.config.project_path = "old"
context.config.inc = inc
return context

View file

@ -6,25 +6,26 @@
@File : test_write_prd.py
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`.
"""
import uuid
from pathlib import Path
import pytest
from metagpt.actions import UserRequirement, WritePRD
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.const import DEFAULT_WORKSPACE_ROOT, REQUIREMENT_FILENAME
from metagpt.logs import logger
from metagpt.roles.product_manager import ProductManager
from metagpt.roles.role import RoleReactMode
from metagpt.schema import Message
from metagpt.schema import AIMessage, Message
from metagpt.utils.common import any_to_str
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE
from tests.metagpt.actions.test_write_code import setup_inc_workdir
from metagpt.utils.project_repo import ProjectRepo
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE
@pytest.mark.asyncio
async def test_write_prd(new_filename, context):
product_manager = ProductManager(context=context)
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
product_manager.rc.react_mode = RoleReactMode.BY_ORDER
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
assert prd.cause_by == any_to_str(WritePRD)
@ -34,38 +35,39 @@ async def test_write_prd(new_filename, context):
# Assert the prd is not None or empty
assert prd is not None
assert prd.content != ""
assert product_manager.context.repo.docs.prd.changed_files
repo = ProjectRepo(context.kwargs.project_path)
assert repo.docs.prd.changed_files
repo.git_repo.archive()
@pytest.mark.asyncio
async def test_write_prd_inc(new_filename, context, git_dir):
context = setup_inc_workdir(context, inc=True)
await context.repo.docs.prd.save("1.txt", PRD_SAMPLE)
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)
# Mock incremental requirement
context.config.inc = True
context.config.project_path = context.kwargs.project_path
repo = ProjectRepo(context.config.project_path)
await repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)
action = WritePRD(context=context)
prd = await action.run(Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None))
prd = await action.run([Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)])
logger.info(NEW_REQUIREMENT_SAMPLE)
logger.info(prd)
# Assert the prd is not None or empty
assert prd is not None
assert prd.content != ""
assert "Refined Requirements" in prd.content
assert repo.git_repo.changed_files
@pytest.mark.asyncio
async def test_fix_debug(new_filename, context, git_dir):
context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name
# Mock legacy project
context.kwargs.project_path = str(git_dir)
repo = ProjectRepo(context.kwargs.project_path)
repo.with_src_path(git_dir.name)
await repo.srcs.save(filename="main.py", content='if __name__ == "__main__":\nmain()')
requirements = "ValueError: undefined variable `st`."
await repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
await context.repo.with_src_path(context.src_workspace).srcs.save(
filename="main.py", content='if __name__ == "__main__":\nmain()'
)
requirements = "Please fix the bug in the code."
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
action = WritePRD(context=context)
prd = await action.run(Message(content=requirements, instruct_content=None))
prd = await action.run([Message(content=requirements, instruct_content=None)])
logger.info(prd)
# Assert the prd is not None or empty
@ -73,5 +75,40 @@ async def test_fix_debug(new_filename, context, git_dir):
assert prd.content != ""
@pytest.mark.asyncio
async def test_write_prd_api(context):
action = WritePRD()
result = await action.run(user_requirement="write a snake game.")
assert isinstance(result, AIMessage)
assert result.content
assert str(DEFAULT_WORKSPACE_ROOT) in result.content
result = await action.run(
user_requirement="write a snake game.",
output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"),
)
assert isinstance(result, AIMessage)
assert result.content
assert result.instruct_content
assert str(context.config.project_path) in result.content
legacy_prd_filename = result.instruct_content.changed_prd_filenames[-1]
result = await action.run(user_requirement="Add moving enemy.", legacy_prd_filename=legacy_prd_filename)
assert isinstance(result, AIMessage)
assert result.content
assert str(DEFAULT_WORKSPACE_ROOT) in result.content
result = await action.run(
user_requirement="Add moving enemy.",
output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"),
legacy_prd_filename=legacy_prd_filename,
)
assert isinstance(result, AIMessage)
assert result.content
assert result.instruct_content
assert str(context.config.project_path) in result.content
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -25,10 +25,6 @@ class TestSimpleEngine:
def mock_simple_directory_reader(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
@pytest.fixture
def mock_vector_store_index(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
@pytest.fixture
def mock_get_retriever(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
@ -45,7 +41,6 @@ class TestSimpleEngine:
self,
mocker,
mock_simple_directory_reader,
mock_vector_store_index,
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
@ -81,11 +76,8 @@ class TestSimpleEngine:
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
configs=retriever_configs, index=mock_vector_store_index.return_value
)
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
@ -119,7 +111,7 @@ class TestSimpleEngine:
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is not None
assert engine._transformations is not None
def test_from_objs_with_bm25_config(self):
# Setup
@ -137,6 +129,7 @@ class TestSimpleEngine:
def test_from_index(self, mocker, mock_llm, mock_embedding):
# Mock
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index.as_retriever.return_value = "retriever"
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
mock_get_index.return_value = mock_index
@ -149,7 +142,7 @@ class TestSimpleEngine:
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is mock_index
assert engine._retriever == "retriever"
@pytest.mark.asyncio
async def test_asearch(self, mocker):
@ -200,14 +193,11 @@ class TestSimpleEngine:
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index._transformations = mocker.MagicMock()
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
mock_run_transformations.return_value = ["node1", "node2"]
# Setup
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
engine = SimpleEngine(retriever=mock_retriever)
input_files = ["test_file1", "test_file2"]
# Exec
@ -230,7 +220,7 @@ class TestSimpleEngine:
return ""
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
engine = SimpleEngine(retriever=mock_retriever)
# Exec
engine.add_objs(objs=objs)

View file

@ -97,6 +97,5 @@ class TestConfigBasedFactory:
def test_val_from_config_or_kwargs_key_error(self):
# Test KeyError when the key is not found in both config object and kwargs
config = DummyConfig(name=None)
with pytest.raises(KeyError) as exc_info:
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert val is None

View file

@ -1,5 +1,6 @@
import pytest
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory:
self.embedding_factory = RAGEmbeddingFactory()
@pytest.fixture
def mock_openai_embedding(self, mocker):
def mock_config(self, mocker):
return mocker.patch("metagpt.rag.factories.embedding.config")
@staticmethod
def mock_openai_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
@pytest.fixture
def mock_azure_embedding(self, mocker):
@staticmethod
def mock_azure_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
def test_get_rag_embedding_openai(self, mock_openai_embedding):
# Exec
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
@staticmethod
def mock_gemini_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding")
# Assert
mock_openai_embedding.assert_called_once()
@staticmethod
def mock_ollama_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding")
def test_get_rag_embedding_azure(self, mock_azure_embedding):
# Exec
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
# Assert
mock_azure_embedding.assert_called_once()
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
@pytest.mark.parametrize(
("mock_func", "embedding_type"),
[
(mock_openai_embedding, LLMType.OPENAI),
(mock_azure_embedding, LLMType.AZURE),
(mock_openai_embedding, EmbeddingType.OPENAI),
(mock_azure_embedding, EmbeddingType.AZURE),
(mock_gemini_embedding, EmbeddingType.GEMINI),
(mock_ollama_embedding, EmbeddingType.OLLAMA),
],
)
def test_get_rag_embedding(self, mock_func, embedding_type, mocker):
# Mock
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
mock = mock_func(mocker)
# Exec
self.embedding_factory.get_rag_embedding(embedding_type)
# Assert
mock.assert_called_once()
def test_get_rag_embedding_default(self, mocker, mock_config):
# Mock
mock_openai_embedding = self.mock_openai_embedding(mocker)
mock_config.embedding.api_type = None
mock_config.llm.api_type = LLMType.OPENAI
# Exec
@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory:
# Assert
mock_openai_embedding.assert_called_once()
@pytest.mark.parametrize(
"model, embed_batch_size, expected_params",
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
)
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
# Mock
mock_config.embedding.model = model
mock_config.embedding.embed_batch_size = embed_batch_size
# Setup
test_params = {}
# Exec
self.embedding_factory._try_set_model_and_batch_size(test_params)
# Assert
assert test_params == expected_params
def test_resolve_embedding_type(self, mock_config):
# Mock
mock_config.embedding.api_type = EmbeddingType.OPENAI
# Exec
embedding_type = self.embedding_factory._resolve_embedding_type()
# Assert
assert embedding_type == EmbeddingType.OPENAI
def test_resolve_embedding_type_exception(self, mock_config):
# Mock
mock_config.embedding.api_type = None
mock_config.llm.api_type = LLMType.GEMINI
# Assert
with pytest.raises(TypeError):
self.embedding_factory._resolve_embedding_type()
def test_raise_for_key(self):
with pytest.raises(ValueError):
self.embedding_factory._raise_for_key("key")

View file

@ -1,6 +1,8 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
@ -43,6 +45,14 @@ class TestRetrieverFactory:
def mock_es_vector_store(self, mocker):
return mocker.MagicMock(spec=ElasticsearchStore)
@pytest.fixture
def mock_nodes(self, mocker):
return [TextNode(text="msg")]
@pytest.fixture
def mock_embedding(self):
return MockEmbedding(embed_dim=1)
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
@ -52,42 +62,40 @@ class TestRetrieverFactory:
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
def test_get_retriever_with_bm25_config(self, mocker, mock_nodes):
mock_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes)
assert isinstance(retriever, DynamicBM25Retriever)
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding):
mock_faiss_config = FAISSRetrieverConfig(dimensions=1)
mock_bm25_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
retriever = self.retriever_factory.get_retriever(
configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding
)
assert isinstance(retriever, SimpleHybridRetriever)
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding):
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
assert isinstance(retriever, ElasticsearchRetriever)
@ -111,3 +119,19 @@ class TestRetrieverFactory:
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index
def test_get_or_build_when_get(self, mocker):
want = "existing_index"
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want)
got = self.retriever_factory._build_es_index(None)
assert got == want
def test_get_or_build_when_build(self, mocker):
want = "call_build_es_index"
mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want)
got = self.retriever_factory._build_es_index(None)
assert got == want

View file

@ -392,5 +392,11 @@ async def test_parse_resources(context, content: str, key_descriptions):
assert k in result
@pytest.mark.parametrize(("name", "value"), [("c1", {"age": 10, "name": "Alice"}), ("", {"path": __file__})])
def test_create_instruct_value(name, value):
obj = Message.create_instruct_value(kvs=value, class_name=name)
assert obj.model_dump() == value
if __name__ == "__main__":
pytest.main([__file__, "-s"])