MetaGPT/tests/conftest.py
2024-01-31 00:08:09 +08:00

233 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/1 12:10
@Author : alexanderwu
@File : conftest.py
"""
import asyncio
import json
import logging
import os
import re
import uuid
from typing import Callable
import pytest
from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
from metagpt.context import Context as MetagptContext
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
from tests.mock.mock_aiohttp import MockAioResponse
from tests.mock.mock_curl_cffi import MockCurlCffiResponse
from tests.mock.mock_httplib2 import MockHttplib2Response
from tests.mock.mock_llm import MockLLM
# Used globally for producing a response cache that holds only new and useful entries.
RSP_CACHE_NEW = {}

# Integer flag read from the environment variable of the same name (defaults to 1).
# NOTE: should change to default 0 (False) once mock is complete
ALLOW_OPENAI_API_CALL = int(os.getenv("ALLOW_OPENAI_API_CALL", 1))
@pytest.fixture(scope="session")
def rsp_cache():
    """Yield the session-wide LLM response cache and persist it on teardown.

    The cache is loaded from ``rsp_cache.json`` under ``TEST_DATA_PATH`` when that
    file exists; otherwise an empty dict is used. After the session, the yielded
    dict is written back to ``rsp_cache.json`` and ``RSP_CACHE_NEW`` is exported to
    ``rsp_cache_new.json`` in the same directory.
    """
    cache_path = TEST_DATA_PATH / "rsp_cache.json"  # read repo-provided
    export_path = TEST_DATA_PATH / "rsp_cache_new.json"  # exporting a new copy

    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as src:
            cache = json.load(src)

    yield cache

    # Teardown: write the full cache first, then the trimmed copy.
    for target, payload in ((cache_path, cache), (export_path, RSP_CACHE_NEW)):
        with open(target, "w", encoding="utf-8") as dst:
            json.dump(payload, dst, indent=4, ensure_ascii=False)
# Hook to capture the test result
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """Store the report of the "call" phase on the test item as ``item.test_outcome``.

    Reports whose ``when`` is anything other than ``"call"`` are ignored.
    """
    wrapped = yield
    report = wrapped.get_result()
    if report.when != "call":
        return
    item.test_outcome = report
@pytest.fixture(scope="function", autouse=True)
def llm_mock(rsp_cache, mocker, request):
    """Patch the LLM ask methods with a cache-backed ``MockLLM`` for every test.

    After the test, if its call-phase outcome was recorded and passed, every
    response candidate collected by the mock is copied into ``RSP_CACHE_NEW``;
    candidates whose key is not yet in the shared cache are also added to it.
    """
    mock_llm = MockLLM(allow_open_api_call=ALLOW_OPENAI_API_CALL)
    mock_llm.rsp_cache = rsp_cache

    patch_plan = (
        ("metagpt.provider.base_llm.BaseLLM.aask", "aask"),
        ("metagpt.provider.base_llm.BaseLLM.aask_batch", "aask_batch"),
        ("metagpt.provider.openai_api.OpenAILLM.aask_code", "aask_code"),
    )
    for target, method_name in patch_plan:
        # Look the bound method up right before patching, one target at a time.
        mocker.patch(target, getattr(mock_llm, method_name))

    yield mocker

    node = request.node
    if not hasattr(node, "test_outcome") or not node.test_outcome.passed:
        return

    for candidate in mock_llm.rsp_candidates or ():
        # Only the first key/value pair is used for the membership check and the log line.
        key = list(candidate.keys())[0]
        value = list(candidate.values())[0]
        if key not in mock_llm.rsp_cache:
            logger.info(f"Added '{key[:100]} ... -> {str(value)[:20]} ...' to response cache")
            mock_llm.rsp_cache.update(candidate)
        RSP_CACHE_NEW.update(candidate)
class Context:
    """Small holder that creates an ``LLM`` instance and exposes it as ``llm_api``."""

    def __init__(self):
        self._llm_api = LLM()
        self._llm_ui = None

    @property
    def llm_api(self):
        """Return the ``LLM`` instance created in ``__init__``."""
        # Caching flow noted by the original author (this accessor itself only
        # returns the stored instance):
        # 1. Initialize the LLM with cached results.
        # 2. If the query is cached, return the cached result directly.
        # 3. If the query is not cached, call llm_api and return its result.
        # 4. If the query is cached, update the cached result.
        return self._llm_api
@pytest.fixture(scope="package")
def llm_api():
    """Yield the ``llm_api`` of a fresh ``Context``, logging around setup and teardown."""
    logger.info("Setting up the test")
    holder = Context()
    api = holder.llm_api
    yield api
    logger.info("Tearing down the test")
@pytest.fixture
def proxy():
    """Return a coroutine function that starts a minimal local forwarding proxy.

    Awaiting the returned ``proxy_func`` starts an asyncio server on
    ``127.0.0.1`` with an OS-assigned port and returns ``(server, url)`` where
    ``url`` is ``"http://<host>:<port>"``. Each client request head is printed
    with a ``Proxy: `` prefix so tests can check it with the ``capfd`` fixture.
    """
    # Parses the request line: method, URI, and the host / optional port inside the URI.
    pattern = re.compile(
        rb"(?P<method>[a-zA-Z]+) (?P<uri>(\w+://)?(?P<host>[^\s\'\"<>\[\]{}|/:]+)(:(?P<port>\d+))?[^\s\'\"<>\[\]{}|]*) "
    )

    async def pipe(reader, writer):
        """Copy bytes from ``reader`` to ``writer`` until EOF, then close ``writer``."""
        try:
            while not reader.at_eof():
                writer.write(await reader.read(2048))
        finally:
            # Close even when the read fails, so the socket is not leaked.
            writer.close()

    async def handle_client(reader, writer):
        """Read one request head, connect to the target host and relay both directions."""
        data = await reader.readuntil(b"\r\n\r\n")
        print(f"Proxy: {data}")  # checking with capfd fixture
        infos = pattern.match(data)
        if infos is None:
            # Malformed request line: nothing to forward, release the client socket.
            writer.close()
            return
        host, port = infos.group("host"), infos.group("port")
        port = int(port) if port else 80
        try:
            remote_reader, remote_writer = await asyncio.open_connection(host, port)
        except OSError:
            # Upstream unreachable: do not leave the client connection hanging.
            writer.close()
            raise
        if data.startswith(b"CONNECT"):
            # Tunnel request: acknowledge to the client, do not forward the head.
            writer.write(b"HTTP/1.1 200 Connection Established\r\n\r\n")
        else:
            # Plain request: forward the already-consumed head to the remote side.
            remote_writer.write(data)
        await asyncio.gather(pipe(reader, remote_writer), pipe(remote_reader, writer))

    async def proxy_func():
        """Start the proxy server and return it together with its URL."""
        server = await asyncio.start_server(handle_client, "127.0.0.1", 0)
        return server, "http://{}:{}".format(*server.sockets[0].getsockname())

    return proxy_func
# see https://github.com/Delgan/loguru/issues/59#issuecomment-466591978
@pytest.fixture
def loguru_caplog(caplog):
    """Yield ``caplog`` while messages from the project logger are forwarded to ``logging``.

    A sink is added to ``logger`` that hands each record to the standard-library
    logger of the same name. The sink is removed again on teardown so that sinks
    do not accumulate (and duplicate records) across tests.
    """

    class PropagateHandler(logging.Handler):
        """Re-dispatch every record to the standard logger named ``record.name``."""

        def emit(self, record):
            logging.getLogger(record.name).handle(record)

    handler_id = logger.add(PropagateHandler(), format="{message}")
    try:
        yield caplog
    finally:
        logger.remove(handler_id)
@pytest.fixture(scope="function")
def context(request):
    """Provide a ``MetagptContext`` bound to a throwaway git repository.

    The repository lives under ``DEFAULT_WORKSPACE_ROOT / "unittest/<random hex>"``
    and is deleted by a finalizer registered on ``request``.
    """
    test_ctx = MetagptContext()
    repo_dir = DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}"
    test_ctx.git_repo = GitRepository(local_path=repo_dir)
    test_ctx.repo = ProjectRepo(test_ctx.git_repo)

    def _remove_git_repo():
        # Destroy the git repo when this fixture is finalized.
        if test_ctx.git_repo:
            test_ctx.git_repo.delete_repository()

    # Register the function for destroying the environment.
    request.addfinalizer(_remove_git_repo)
    return test_ctx
@pytest.fixture(scope="session", autouse=True)
def init_config():
    """Session-scoped autouse placeholder fixture; it performs no work."""
    return None
@pytest.fixture(scope="function")
def new_filename(mocker):
    """Patch ``FileRepository.new_filename`` to always return ``"20240101"``."""

    # NOTE: Mock new filename to make reproducible llm aask, should consider changing after implementing requirement segmentation
    def _fixed_filename():
        return "20240101"

    mocker.patch("metagpt.utils.file_repository.FileRepository.new_filename", _fixed_filename)
    yield mocker
@pytest.fixture(scope="session")
def search_rsp_cache():
    """Yield the session-wide search response cache and persist it on teardown.

    The cache is loaded from ``search_rsp_cache.json`` under ``TEST_DATA_PATH``
    when that file exists; otherwise an empty dict is used. After the session the
    yielded dict is written back to the same file.
    """
    rsp_cache_file_path = TEST_DATA_PATH / "search_rsp_cache.json"  # read repo-provided
    if os.path.exists(rsp_cache_file_path):
        # Explicit UTF-8: the file is dumped with ensure_ascii=False, so relying on
        # the platform default encoding would make reads/writes locale-dependent.
        with open(rsp_cache_file_path, "r", encoding="utf-8") as f1:
            rsp_cache_json = json.load(f1)
    else:
        rsp_cache_json = {}
    yield rsp_cache_json
    with open(rsp_cache_file_path, "w", encoding="utf-8") as f2:
        json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
@pytest.fixture
def aiohttp_mocker(mocker):
    """Replace ``aiohttp.ClientSession`` request methods with a mock response class.

    Yields a fresh subclass of ``MockAioResponse`` so that attributes set on it by
    a test do not affect the base class.
    """
    MockResponse = type("MockResponse", (MockAioResponse,), {})

    def make_verb_method(verb):
        # A plain function is used so that it binds ``self`` like a normal method.
        def call(self, url, **kwargs):
            return MockResponse(self, verb, url, **kwargs)

        return call

    mocker.patch("aiohttp.ClientSession.request", MockResponse)
    for verb in ("get", "post", "delete", "patch"):
        mocker.patch(f"aiohttp.ClientSession.{verb}", make_verb_method(verb))
    yield MockResponse
@pytest.fixture
def curl_cffi_mocker(mocker):
    """Replace ``curl_cffi.requests.Session.request`` with a mock response factory.

    Yields a fresh subclass of ``MockCurlCffiResponse``.
    """
    MockResponse = type("MockResponse", (MockCurlCffiResponse,), {})

    def _fake_request(self, *args, **kwargs):
        # Forward the session and all call arguments to the mock response.
        return MockResponse(self, *args, **kwargs)

    mocker.patch("curl_cffi.requests.Session.request", _fake_request)
    yield MockResponse
@pytest.fixture
def httplib2_mocker(mocker):
    """Replace ``httplib2.Http.request`` with a mock response factory.

    Yields a fresh subclass of ``MockHttplib2Response``.
    """
    MockResponse = type("MockResponse", (MockHttplib2Response,), {})

    def _fake_request(self, *args, **kwargs):
        # Forward the Http object and all call arguments to the mock response.
        return MockResponse(self, *args, **kwargs)

    mocker.patch("httplib2.Http.request", _fake_request)
    yield MockResponse
@pytest.fixture
def search_engine_mocker(aiohttp_mocker, curl_cffi_mocker, httplib2_mocker, search_rsp_cache):
    """Wire the three HTTP mock classes to one shared cache and one ``check_funcs`` dict.

    Yields the (initially empty) ``check_funcs`` dict that is attached to all
    three mock classes.
    """
    # aiohttp_mocker: serpapi/serper
    # httplib2_mocker: google
    # curl_cffi_mocker: ddg
    check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
    mock_classes = (aiohttp_mocker, httplib2_mocker, curl_cffi_mocker)
    shared_attrs = (("rsp_cache", search_rsp_cache), ("check_funcs", check_funcs))
    for attr_name, shared_obj in shared_attrs:
        for mock_cls in mock_classes:
            setattr(mock_cls, attr_name, shared_obj)
    yield check_funcs