tool management at one place, add aask_code mock, azure mock

This commit is contained in:
yzlin 2024-01-11 22:55:31 +08:00
parent 9e0b9745be
commit e99c5f29f4
9 changed files with 167 additions and 74 deletions

View file

@ -34,14 +34,14 @@ def rsp_cache():
rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache.json" # read repo-provided
new_rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache_new.json" # exporting a new copy
if os.path.exists(rsp_cache_file_path):
with open(rsp_cache_file_path, "r") as f1:
with open(rsp_cache_file_path, "r", encoding="utf-8") as f1:
rsp_cache_json = json.load(f1)
else:
rsp_cache_json = {}
yield rsp_cache_json
with open(rsp_cache_file_path, "w") as f2:
with open(rsp_cache_file_path, "w", encoding="utf-8") as f2:
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
with open(new_rsp_cache_file_path, "w") as f2:
with open(new_rsp_cache_file_path, "w", encoding="utf-8") as f2:
json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False)
@ -60,6 +60,7 @@ def llm_mock(rsp_cache, mocker, request):
llm.rsp_cache = rsp_cache
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask_batch", llm.aask_batch)
mocker.patch("metagpt.provider.openai_api.OpenAILLM.aask_code", llm.aask_code)
yield mocker
if hasattr(request.node, "test_outcome") and request.node.test_outcome.passed:
if llm.rsp_candidates:
@ -67,7 +68,7 @@ def llm_mock(rsp_cache, mocker, request):
cand_key = list(rsp_candidate.keys())[0]
cand_value = list(rsp_candidate.values())[0]
if cand_key not in llm.rsp_cache:
logger.info(f"Added '{cand_key[:100]} ... -> {cand_value[:20]} ...' to response cache")
logger.info(f"Added '{cand_key[:100]} ... -> {str(cand_value)[:20]} ...' to response cache")
llm.rsp_cache.update(rsp_candidate)
RSP_CACHE_NEW.update(rsp_candidate)

View file

@ -1,4 +1,12 @@
from metagpt.actions.write_plan import Plan, Task, precheck_update_plan_from_rsp
import pytest
from metagpt.actions.write_plan import (
Plan,
Task,
WritePlan,
precheck_update_plan_from_rsp,
)
from metagpt.schema import Message
def test_precheck_update_plan_from_rsp():
@ -12,3 +20,12 @@ def test_precheck_update_plan_from_rsp():
invalid_rsp = "wrong"
success, _ = precheck_update_plan_from_rsp(invalid_rsp, plan)
assert not success
@pytest.mark.asyncio
async def test_write_plan():
rsp = await WritePlan().run(context=[Message("run analysis on sklearn iris dataset", role="user")])
assert "task_id" in rsp
assert "instruction" in rsp
assert "json" not in rsp # the output should be the content inside ```json ```

View file

@ -0,0 +1,13 @@
import pytest
from metagpt.logs import logger
from metagpt.roles.code_interpreter import CodeInterpreter
@pytest.mark.asyncio
async def test_code_interpreter():
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
ci = CodeInterpreter(goal=requirement, auto_run=True, use_tools=False)
rsp = await ci.run(requirement)
logger.info(rsp)
assert len(rsp.content) > 0

View file

@ -1,10 +1,16 @@
from typing import Optional
import json
from typing import Optional, Union
from metagpt.config import CONFIG
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import Message
OriginalLLM = OpenAILLM if not CONFIG.openai_api_type else AzureOpenAILLM
class MockLLM(OpenAILLM):
class MockLLM(OriginalLLM):
def __init__(self, allow_open_api_call):
super().__init__()
self.allow_open_api_call = allow_open_api_call
@ -58,6 +64,15 @@ class MockLLM(OpenAILLM):
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
async def original_aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
"""
A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it will be mocked.
Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
"""
messages = self._process_message(messages)
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
async def aask(
self,
msg: str,
@ -78,6 +93,12 @@ class MockLLM(OpenAILLM):
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
return rsp
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
messages = self._process_message(messages)
msg_key = json.dumps(messages, ensure_ascii=False)
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
return rsp
async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
if msg_key not in self.rsp_cache:
if not self.allow_open_api_call: