Merge remote-tracking branch 'origin/main' into main-debugger

This commit is contained in:
shenchucheng 2024-02-08 15:36:35 +08:00
commit 924e48e510
333 changed files with 14173 additions and 1813 deletions

View file

@ -14,7 +14,8 @@ class MockAioResponse:
def __init__(self, session, method, url, **kwargs) -> None:
fn = self.check_funcs.get((method, url))
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
_kwargs = {k: v for k, v in kwargs.items() if k != "proxy"}
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(_kwargs, sort_keys=True)}"
self.mng = self.response = None
if self.key not in self.rsp_cache:
self.mng = origin_request(session, method, url, **kwargs)
@ -52,3 +53,7 @@ class MockAioResponse:
data = await self.response.content.read()
self.rsp_cache[self.key] = str(data)
return data
def raise_for_status(self):
if self.response:
self.response.raise_for_status()

View file

@ -1,13 +1,22 @@
from typing import Optional
import json
from typing import Optional, Union
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import Message
OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM
class MockLLM(OpenAILLM):
class MockLLM(OriginalLLM):
def __init__(self, allow_open_api_call):
super().__init__(config.get_openai_llm())
original_llm_config = (
config.get_openai_llm() if config.llm.api_type == LLMType.OPENAI else config.get_azure_llm()
)
super().__init__(original_llm_config)
self.allow_open_api_call = allow_open_api_call
self.rsp_cache: dict = {}
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
@ -35,6 +44,7 @@ class MockLLM(OpenAILLM):
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
stream=True,
):
@ -47,7 +57,7 @@ class MockLLM(OpenAILLM):
message = []
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
message.append(self._user_msg(msg, images=images))
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp
@ -61,11 +71,20 @@ class MockLLM(OpenAILLM):
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
async def original_aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
"""
A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it will be mocked.
Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
"""
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
async def aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
stream=True,
) -> str:
@ -73,7 +92,7 @@ class MockLLM(OpenAILLM):
if system_msgs:
joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#"
msg_key = joined_system_msg + msg_key
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, timeout, stream)
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream)
return rsp
async def aask_batch(self, msgs: list, timeout=3) -> str:
@ -81,6 +100,12 @@ class MockLLM(OpenAILLM):
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
return rsp
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
messages = self._process_message(messages)
msg_key = json.dumps(messages, ensure_ascii=False)
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
return rsp
async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
if msg_key not in self.rsp_cache:
if not self.allow_open_api_call: