add openai api call switch; fix ocr

This commit is contained in:
yzlin 2024-01-05 14:34:44 +08:00
parent ab04f610a3
commit 1249d12b6f
4 changed files with 30 additions and 22 deletions

View file

@ -88,6 +88,8 @@ class InvoiceOCR(Action):
async def _ocr(invoice_file_path: Path):
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
for result in ocr_result[0]:
result[1] = (result[1][0], round(result[1][1], 2)) # round long confidence scores to reduce token costs
return ocr_result
async def run(self, file_path: Path, *args, **kwargs) -> list:

View file

@ -23,6 +23,8 @@ from metagpt.utils.git_repository import GitRepository
from tests.mock.mock_llm import MockLLM
RSP_CACHE_NEW = {} # used globally for producing new and useful only response cache
ALLOW_OPENAI_API_CALL = os.environ.get("ALLOW_OPENAI_API_CALL", False)
ALLOW_OPENAI_API_CALL = True
@pytest.fixture(scope="session")
@ -53,7 +55,7 @@ def pytest_runtest_makereport(item, call):
@pytest.fixture(scope="function", autouse=True)
def llm_mock(rsp_cache, mocker, request):
llm = MockLLM()
llm = MockLLM(allow_open_api_call=ALLOW_OPENAI_API_CALL)
llm.rsp_cache = rsp_cache
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask_batch", llm.aask_batch)

File diff suppressed because one or more lines are too long

View file

@ -5,8 +5,9 @@ from metagpt.provider.openai_api import OpenAILLM
class MockLLM(OpenAILLM):
def __init__(self):
def __init__(self, allow_open_api_call):
super().__init__()
self.allow_open_api_call = allow_open_api_call
self.rsp_cache: dict = {}
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
@ -69,20 +70,24 @@ class MockLLM(OpenAILLM):
if system_msgs:
joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#"
msg_key = joined_system_msg + msg_key
if msg_key not in self.rsp_cache:
# Call the original unmocked method
rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
else:
logger.warning("Use response cache")
rsp = self.rsp_cache[msg_key]
self.rsp_candidates.append({msg_key: rsp})
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, timeout, stream)
return rsp
async def aask_batch(self, msgs: list, timeout=3) -> str:
msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
return rsp
async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
if msg_key not in self.rsp_cache:
if not self.allow_open_api_call:
raise ValueError(
"In current test setting, api call is not allowed, you should properly mock your tests, "
"or add expected api response in tests/data/rsp_cache.json. "
f"The prompt you want for api call: {msg_key}"
)
# Call the original unmocked method
rsp = await self.original_aask_batch(msgs, timeout)
rsp = await ask_func(*args, **kwargs)
else:
logger.warning("Use response cache")
rsp = self.rsp_cache[msg_key]