add openai api call switch; fix ocr

This commit is contained in:
yzlin 2024-01-05 14:34:44 +08:00
parent ab04f610a3
commit 1249d12b6f
4 changed files with 30 additions and 22 deletions

View file

@ -5,8 +5,9 @@ from metagpt.provider.openai_api import OpenAILLM
class MockLLM(OpenAILLM):
def __init__(self):
def __init__(self, allow_open_api_call):
super().__init__()
self.allow_open_api_call = allow_open_api_call
self.rsp_cache: dict = {}
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
@ -69,20 +70,24 @@ class MockLLM(OpenAILLM):
if system_msgs:
joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#"
msg_key = joined_system_msg + msg_key
if msg_key not in self.rsp_cache:
# Call the original unmocked method
rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
else:
logger.warning("Use response cache")
rsp = self.rsp_cache[msg_key]
self.rsp_candidates.append({msg_key: rsp})
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, timeout, stream)
return rsp
async def aask_batch(self, msgs: list, timeout=3) -> str:
msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
return rsp
async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
if msg_key not in self.rsp_cache:
if not self.allow_open_api_call:
raise ValueError(
"In current test setting, api call is not allowed, you should properly mock your tests, "
"or add expected api response in tests/data/rsp_cache.json. "
f"The prompt you want for api call: {msg_key}"
)
# Call the original unmocked method
rsp = await self.original_aask_batch(msgs, timeout)
rsp = await ask_func(*args, **kwargs)
else:
logger.warning("Use response cache")
rsp = self.rsp_cache[msg_key]