add mockllm

This commit is contained in:
yzlin 2024-01-02 23:07:50 +08:00
parent 80800d67d0
commit 9564975541
2 changed files with 72 additions and 8 deletions

View file

@ -9,17 +9,84 @@
import asyncio
import logging
import re
from unittest.mock import Mock
import json
from typing import Optional
import os
import pytest
from metagpt.config import CONFIG, Config
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
from metagpt.llm import LLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
class MockLLM(OpenAILLM):
rsp_cache: dict = {}
async def original_aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
timeout=3,
stream=True,
):
"""A copy of metagpt.provider.base_llm.BaseLLM.aask, we can't use super().aask because it will be mocked"""
if system_msgs:
message = self._system_msgs(system_msgs)
else:
message = [self._default_system_msg()] if self.use_system_prompt else []
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp
async def aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
timeout=3,
stream=True,
) -> str:
if msg not in self.rsp_cache:
# Call the original unmocked method
rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
logger.info(f"added '{rsp[:10]}' ... to response cache")
self.rsp_cache[msg] = rsp
return rsp
else:
logger.info("use response cache")
return self.rsp_cache[msg]
@pytest.fixture(scope="session")
def rsp_cache():
model_version = CONFIG.openai_api_model
rsp_cache_file_path = TEST_DATA_PATH / f"rsp_cache_{model_version}.json" # read repo-provided
new_rsp_cache_file_path = TEST_DATA_PATH / f"rsp_cache_new.json" # exporting a new copy
if os.path.exists(rsp_cache_file_path):
with open(rsp_cache_file_path, "r") as f1:
rsp_cache_json = json.load(f1)
else:
rsp_cache_json = {}
yield rsp_cache_json
with open(new_rsp_cache_file_path, "w") as f2:
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
@pytest.fixture(scope="function")
def llm_mock(rsp_cache, mocker):
llm = MockLLM()
llm.rsp_cache = rsp_cache
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
yield mocker
class Context:
def __init__(self):
self._llm_ui = None
@ -40,12 +107,6 @@ def llm_api():
logger.info("Tearing down the test")
@pytest.fixture(scope="function")
def mock_llm():
# Create a mock LLM for testing
return Mock()
@pytest.fixture(scope="session")
def proxy():
pattern = re.compile(