tool management at one place, add aask_code mock, azure mock

This commit is contained in:
yzlin 2024-01-11 22:55:31 +08:00
parent 9e0b9745be
commit e99c5f29f4
9 changed files with 167 additions and 74 deletions

View file

@ -1,10 +1,16 @@
from typing import Optional
import json
from typing import Optional, Union
from metagpt.config import CONFIG
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import Message
OriginalLLM = OpenAILLM if not CONFIG.openai_api_type else AzureOpenAILLM
class MockLLM(OpenAILLM):
class MockLLM(OriginalLLM):
def __init__(self, allow_open_api_call):
super().__init__()
self.allow_open_api_call = allow_open_api_call
@ -58,6 +64,15 @@ class MockLLM(OpenAILLM):
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
async def original_aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
"""
A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it will be mocked.
Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
"""
messages = self._process_message(messages)
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
async def aask(
self,
msg: str,
@ -78,6 +93,12 @@ class MockLLM(OpenAILLM):
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
return rsp
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
messages = self._process_message(messages)
msg_key = json.dumps(messages, ensure_ascii=False)
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
return rsp
async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
if msg_key not in self.rsp_cache:
if not self.allow_open_api_call: