Merge branch 'dev' into code_intepreter

This commit is contained in:
yzlin 2024-01-31 00:08:09 +08:00
commit 2fcb2a1cfe
282 changed files with 6993 additions and 3210 deletions

View file

@ -0,0 +1,41 @@
import json
from typing import Callable
from aiohttp.client import ClientSession
origin_request = ClientSession.request
class MockAioResponse:
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
rsp_cache: dict[str, str] = {}
name = "aiohttp"
def __init__(self, session, method, url, **kwargs) -> None:
fn = self.check_funcs.get((method, url))
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
self.mng = self.response = None
if self.key not in self.rsp_cache:
self.mng = origin_request(session, method, url, **kwargs)
async def __aenter__(self):
if self.response:
await self.response.__aenter__()
elif self.mng:
self.response = await self.mng.__aenter__()
return self
async def __aexit__(self, *args, **kwargs):
if self.response:
await self.response.__aexit__(*args, **kwargs)
self.response = None
elif self.mng:
await self.mng.__aexit__(*args, **kwargs)
self.mng = None
async def json(self, *args, **kwargs):
if self.key in self.rsp_cache:
return self.rsp_cache[self.key]
data = await self.response.json(*args, **kwargs)
self.rsp_cache[self.key] = data
return data

View file

@ -0,0 +1,22 @@
import json
from typing import Callable
from curl_cffi import requests
origin_request = requests.Session.request
class MockCurlCffiResponse(requests.Response):
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
rsp_cache: dict[str, str] = {}
name = "curl-cffi"
def __init__(self, session, method, url, **kwargs) -> None:
super().__init__()
fn = self.check_funcs.get((method, url))
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
self.response = None
if self.key not in self.rsp_cache:
response = origin_request(session, method, url, **kwargs)
self.rsp_cache[self.key] = response.content.decode()
self.content = self.rsp_cache[self.key].encode()

View file

@ -0,0 +1,29 @@
import json
from typing import Callable
from urllib.parse import parse_qsl, urlparse
import httplib2
origin_request = httplib2.Http.request
class MockHttplib2Response(httplib2.Response):
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
rsp_cache: dict[str, str] = {}
name = "httplib2"
def __init__(self, http, uri, method="GET", **kwargs) -> None:
url = uri.split("?")[0]
result = urlparse(uri)
params = dict(parse_qsl(result.query))
fn = self.check_funcs.get((method, uri))
new_kwargs = {"params": params}
key = f"{self.name}-{method}-{url}-{fn(new_kwargs) if fn else json.dumps(new_kwargs)}"
if key not in self.rsp_cache:
_, self.content = origin_request(http, uri, method, **kwargs)
self.rsp_cache[key] = self.content.decode()
self.content = self.rsp_cache[key]
def __iter__(self):
yield self
yield self.content.encode()

View file

@ -1,18 +1,18 @@
import json
from typing import Optional, Union
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import Message
OriginalLLM = OpenAILLM if not CONFIG.openai_api_type else AzureOpenAILLM
OriginalLLM = OpenAILLM if not config.openai_api_type else AzureOpenAILLM
class MockLLM(OriginalLLM):
def __init__(self, allow_open_api_call):
super().__init__()
super().__init__(config.get_openai_llm())
self.allow_open_api_call = allow_open_api_call
self.rsp_cache: dict = {}
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
@ -47,7 +47,9 @@ class MockLLM(OriginalLLM):
if system_msgs:
message = self._system_msgs(system_msgs)
else:
message = [self._default_system_msg()] if self.use_system_prompt else []
message = [self._default_system_msg()]
if not self.use_system_prompt:
message = []
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))