mock writeprd new filename, improve cache usefulness

This commit is contained in:
yzlin 2024-01-04 20:45:38 +08:00
parent 271ecc30a2
commit 106543b3ca
10 changed files with 103 additions and 91 deletions

View file

@ -37,7 +37,7 @@ jobs:
path: |
./unittest.txt
./htmlcov/
./tests/data/rsp_cache.json
./tests/data/rsp_cache_new.json
retention-days: 3
if: ${{ always() }}

View file

@ -22,12 +22,14 @@ from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
from tests.mock.mock_llm import MockLLM
RSP_CACHE_NEW = {} # used globally for producing new and useful only response cache
@pytest.fixture(scope="session")
def rsp_cache():
# model_version = CONFIG.openai_api_model
rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache.json" # read repo-provided
# new_rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache_new.json" # exporting a new copy
new_rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache_new.json" # exporting a new copy
if os.path.exists(rsp_cache_file_path):
with open(rsp_cache_file_path, "r") as f1:
rsp_cache_json = json.load(f1)
@ -36,6 +38,8 @@ def rsp_cache():
yield rsp_cache_json
with open(rsp_cache_file_path, "w") as f2:
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
with open(new_rsp_cache_file_path, "w") as f2:
json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False)
# Hook to capture the test result
@ -57,7 +61,12 @@ def llm_mock(rsp_cache, mocker, request):
if hasattr(request.node, "test_outcome") and request.node.test_outcome.passed:
if llm.rsp_candidates:
for rsp_candidate in llm.rsp_candidates:
llm.rsp_cache.update(rsp_candidate)
cand_key = list(rsp_candidate.keys())[0]
cand_value = list(rsp_candidate.values())[0]
if cand_key not in llm.rsp_cache:
logger.info(f"Added '{cand_key[:100]} ... -> {cand_value[:20]} ...' to response cache")
llm.rsp_cache.update(rsp_candidate)
RSP_CACHE_NEW.update(rsp_candidate)
class Context:
@ -142,6 +151,12 @@ def init_config():
Config()
@pytest.fixture(scope="function")
def new_filename(mocker):
mocker.patch("metagpt.utils.file_repository.FileRepository.new_filename", lambda: "20240101")
yield mocker
@pytest.fixture
def aiohttp_mocker(mocker):
class MockAioResponse:

File diff suppressed because one or more lines are too long

View file

@ -18,7 +18,7 @@ from metagpt.utils.file_repository import FileRepository
@pytest.mark.asyncio
async def test_write_prd():
async def test_write_prd(new_filename):
product_manager = ProductManager()
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)

View file

@ -13,7 +13,7 @@ from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
async def test_product_manager():
async def test_product_manager(new_filename):
product_manager = ProductManager()
rsp = await product_manager.run(MockMessages.req)
logger.info(rsp)

View file

@ -10,7 +10,7 @@ from metagpt.schema import Message
@pytest.mark.asyncio
async def test_product_manager_deserialize():
async def test_product_manager_deserialize(new_filename):
role = ProductManager()
ser_role_dict = role.model_dump(by_alias=True)
new_role = ProductManager(**ser_role_dict)

View file

@ -9,7 +9,7 @@ from metagpt.actions import WritePRD
from metagpt.schema import Message
def test_action_serialize():
def test_action_serialize(new_filename):
action = WritePRD()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
@ -17,7 +17,7 @@ def test_action_serialize():
@pytest.mark.asyncio
async def test_action_deserialize():
async def test_action_deserialize(new_filename):
action = WritePRD()
serialized_data = action.model_dump()
new_action = WritePRD(**serialized_data)

View file

@ -45,7 +45,7 @@ def test_get_roles(env: Environment):
@pytest.mark.asyncio
async def test_publish_and_process_message(env: Environment):
async def test_publish_and_process_message(env: Environment, new_filename):
if CONFIG.git_repo:
CONFIG.git_repo.delete_repository()
CONFIG.git_repo = None

View file

@ -16,14 +16,14 @@ runner = CliRunner()
@pytest.mark.asyncio
async def test_empty_team():
async def test_empty_team(new_filename):
# FIXME: we're now using "metagpt" cli, so the entrance should be replaced instead.
company = Team()
history = await company.run(idea="Build a simple search system. I will upload my files later.")
logger.info(history)
def test_startup():
def test_startup(new_filename):
args = ["Make a cli snake game"]
result = runner.invoke(app, args)
logger.info(result)

View file

@ -65,24 +65,26 @@ class MockLLM(OpenAILLM):
timeout=3,
stream=True,
) -> str:
if msg not in self.rsp_cache:
msg_key = msg # used to identify it a message has been called before
if system_msgs:
joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#"
msg_key = joined_system_msg + msg_key
if msg_key not in self.rsp_cache:
# Call the original unmocked method
rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
logger.info(f"Added '{rsp[:20]} ...' to response cache")
self.rsp_candidates.append({msg: rsp})
return rsp
else:
logger.warning("Use response cache")
return self.rsp_cache[msg]
rsp = self.rsp_cache[msg_key]
self.rsp_candidates.append({msg_key: rsp})
return rsp
async def aask_batch(self, msgs: list, timeout=3) -> str:
joined_msgs = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
if joined_msgs not in self.rsp_cache:
msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
if msg_key not in self.rsp_cache:
# Call the original unmocked method
rsp = await self.original_aask_batch(msgs, timeout)
logger.info(f"Added '{joined_msgs[:20]} ...' to response cache")
self.rsp_candidates.append({joined_msgs: rsp})
return rsp
else:
logger.warning("Use response cache")
return self.rsp_cache[joined_msgs]
rsp = self.rsp_cache[msg_key]
self.rsp_candidates.append({msg_key: rsp})
return rsp