mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-05 13:52:38 +02:00
add openai api call switch; fix ocr
This commit is contained in:
parent
ab04f610a3
commit
1249d12b6f
4 changed files with 30 additions and 22 deletions
|
|
@ -5,8 +5,9 @@ from metagpt.provider.openai_api import OpenAILLM
|
|||
|
||||
|
||||
class MockLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
def __init__(self, allow_open_api_call):
|
||||
super().__init__()
|
||||
self.allow_open_api_call = allow_open_api_call
|
||||
self.rsp_cache: dict = {}
|
||||
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
|
||||
|
||||
|
|
@ -69,20 +70,24 @@ class MockLLM(OpenAILLM):
|
|||
if system_msgs:
|
||||
joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#"
|
||||
msg_key = joined_system_msg + msg_key
|
||||
if msg_key not in self.rsp_cache:
|
||||
# Call the original unmocked method
|
||||
rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
|
||||
else:
|
||||
logger.warning("Use response cache")
|
||||
rsp = self.rsp_cache[msg_key]
|
||||
self.rsp_candidates.append({msg_key: rsp})
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, timeout, stream)
|
||||
return rsp
|
||||
|
||||
async def aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
|
||||
return rsp
|
||||
|
||||
async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
|
||||
if msg_key not in self.rsp_cache:
|
||||
if not self.allow_open_api_call:
|
||||
raise ValueError(
|
||||
"In current test setting, api call is not allowed, you should properly mock your tests, "
|
||||
"or add expected api response in tests/data/rsp_cache.json. "
|
||||
f"The prompt you want for api call: {msg_key}"
|
||||
)
|
||||
# Call the original unmocked method
|
||||
rsp = await self.original_aask_batch(msgs, timeout)
|
||||
rsp = await ask_func(*args, **kwargs)
|
||||
else:
|
||||
logger.warning("Use response cache")
|
||||
rsp = self.rsp_cache[msg_key]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue