Mirror of https://github.com/FoundationAgents/MetaGPT.git, synced 2026-05-01 11:56:24 +02:00
tool management at one place, add aask_code mock, azure mock

This commit is contained in:
  parent 9e0b9745be
  commit e99c5f29f4

9 changed files with 167 additions and 74 deletions
@@ -34,14 +34,14 @@ def rsp_cache():
     rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache.json"  # read repo-provided
     new_rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache_new.json"  # exporting a new copy
     if os.path.exists(rsp_cache_file_path):
-        with open(rsp_cache_file_path, "r") as f1:
+        with open(rsp_cache_file_path, "r", encoding="utf-8") as f1:
             rsp_cache_json = json.load(f1)
     else:
         rsp_cache_json = {}
     yield rsp_cache_json
-    with open(rsp_cache_file_path, "w") as f2:
+    with open(rsp_cache_file_path, "w", encoding="utf-8") as f2:
         json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
-    with open(new_rsp_cache_file_path, "w") as f2:
+    with open(new_rsp_cache_file_path, "w", encoding="utf-8") as f2:
         json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False)
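Note on the hunk above: the fixture round-trips the response cache JSON, and cached prompts/responses can contain non-ASCII text, so the explicit encoding="utf-8" pairs with ensure_ascii=False. A minimal standalone sketch of that round-trip (file name hypothetical):

import json

# Hypothetical cache entry with non-ASCII content. With the platform default
# encoding (e.g. cp1252 on Windows), the dump below could raise UnicodeEncodeError,
# which is what the explicit encoding="utf-8" in the hunk guards against.
cache = {"prompt: analyse the iris dataset": "response: 鸢尾花数据集已加载"}

with open("rsp_cache_demo.json", "w", encoding="utf-8") as f:
    json.dump(cache, f, indent=4, ensure_ascii=False)  # keep readable text, not \uXXXX escapes

with open("rsp_cache_demo.json", "r", encoding="utf-8") as f:
    assert json.load(f) == cache  # lossless round-trip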
@@ -60,6 +60,7 @@ def llm_mock(rsp_cache, mocker, request):
     llm.rsp_cache = rsp_cache
     mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
     mocker.patch("metagpt.provider.base_llm.BaseLLM.aask_batch", llm.aask_batch)
+    mocker.patch("metagpt.provider.openai_api.OpenAILLM.aask_code", llm.aask_code)
     yield mocker
     if hasattr(request.node, "test_outcome") and request.node.test_outcome.passed:
         if llm.rsp_candidates:
@@ -67,7 +68,7 @@ def llm_mock(rsp_cache, mocker, request):
                 cand_key = list(rsp_candidate.keys())[0]
                 cand_value = list(rsp_candidate.values())[0]
                 if cand_key not in llm.rsp_cache:
-                    logger.info(f"Added '{cand_key[:100]} ... -> {cand_value[:20]} ...' to response cache")
+                    logger.info(f"Added '{cand_key[:100]} ... -> {str(cand_value)[:20]} ...' to response cache")
                     llm.rsp_cache.update(rsp_candidate)
                     RSP_CACHE_NEW.update(rsp_candidate)
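Note on why patching with the bound method llm.aask_code works for the whole session: a bound method is not a descriptor, so every instance of the patched class gets the very same llm-bound callable, and all calls funnel into one mock sharing one rsp_cache (the str() wrapper in the second hunk is needed because aask_code candidates are dicts, not strings). A standalone sketch of the mechanism (class names hypothetical):

class Provider:                 # stands in for OpenAILLM
    def aask_code(self, messages):
        return "real API call"

class FakeLLM:                  # stands in for MockLLM
    def __init__(self):
        self.calls = 0

    def aask_code(self, messages):
        self.calls += 1
        return "cached response"

fake = FakeLLM()
Provider.aask_code = fake.aask_code  # what mocker.patch("...OpenAILLM.aask_code", llm.aask_code) does

assert Provider().aask_code("hi") == "cached response"
assert Provider().aask_code("hi") == "cached response"
assert fake.calls == 2               # two fresh instances, one shared mock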
@@ -1,4 +1,12 @@
-from metagpt.actions.write_plan import Plan, Task, precheck_update_plan_from_rsp
+import pytest
+
+from metagpt.actions.write_plan import (
+    Plan,
+    Task,
+    WritePlan,
+    precheck_update_plan_from_rsp,
+)
+from metagpt.schema import Message


 def test_precheck_update_plan_from_rsp():
@@ -12,3 +20,12 @@ def test_precheck_update_plan_from_rsp():
     invalid_rsp = "wrong"
     success, _ = precheck_update_plan_from_rsp(invalid_rsp, plan)
     assert not success
+
+
+@pytest.mark.asyncio
+async def test_write_plan():
+    rsp = await WritePlan().run(context=[Message("run analysis on sklearn iris dataset", role="user")])
+
+    assert "task_id" in rsp
+    assert "instruction" in rsp
+    assert "json" not in rsp  # the output should be the content inside ```json ```
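The final assert in test_write_plan pins down the output contract: WritePlan.run is expected to return only the payload inside the model's ```json fence, never the fence itself. A hedged illustration of that contract (the stripping below is a stand-in, not the action's actual parser):

raw = '```json\n[{"task_id": "1", "instruction": "load the sklearn iris dataset"}]\n```'

payload = raw.removeprefix("```json\n").removesuffix("\n```")  # Python 3.9+

assert "task_id" in payload
assert "instruction" in payload
assert "json" not in payload  # the fence marker must not leak into the result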
tests/metagpt/roles/test_code_interpreter.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+import pytest
+
+from metagpt.logs import logger
+from metagpt.roles.code_interpreter import CodeInterpreter
+
+
+@pytest.mark.asyncio
+async def test_code_interpreter():
+    requirement = "Run data analysis on sklearn Iris dataset, include a plot"
+    ci = CodeInterpreter(goal=requirement, auto_run=True, use_tools=False)
+    rsp = await ci.run(requirement)
+    logger.info(rsp)
+    assert len(rsp.content) > 0
@@ -1,10 +1,16 @@
-from typing import Optional
+import json
+from typing import Optional, Union

 from metagpt.config import CONFIG
 from metagpt.logs import log_llm_stream, logger
+from metagpt.provider.azure_openai_api import AzureOpenAILLM
 from metagpt.provider.openai_api import OpenAILLM
 from metagpt.schema import Message

+OriginalLLM = OpenAILLM if not CONFIG.openai_api_type else AzureOpenAILLM
+

-class MockLLM(OpenAILLM):
+class MockLLM(OriginalLLM):
     def __init__(self, allow_open_api_call):
         super().__init__()
         self.allow_open_api_call = allow_open_api_call
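The OriginalLLM alias above picks the mock's base class at import time, so the same MockLLM transparently exercises the Azure code paths whenever CONFIG.openai_api_type is set. A standalone sketch of the conditional-base-class pattern (stub classes and the config flag are hypothetical):

class OpenAIStub:
    api_flavor = "openai"

class AzureStub:
    api_flavor = "azure"

openai_api_type = "azure"  # stand-in for CONFIG.openai_api_type

# Classes are first-class objects, so the base can be chosen with an expression:
OriginalStub = OpenAIStub if not openai_api_type else AzureStub

class MockStub(OriginalStub):
    pass

assert MockStub().api_flavor == "azure"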
@@ -58,6 +64,15 @@ class MockLLM(OpenAILLM):
         context.append(self._assistant_msg(rsp_text))
         return self._extract_assistant_rsp(context)

+    async def original_aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
+        """
+        A copy of metagpt.provider.openai_api.OpenAILLM.aask_code; we can't use super().aask_code because it will be mocked.
+        Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
+        """
+        messages = self._process_message(messages)
+        rsp = await self._achat_completion_function(messages, **kwargs)
+        return self.get_choice_function_arguments(rsp)
+
     async def aask(
         self,
         msg: str,
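Why original_aask_code is a verbatim copy rather than a super() call: once the fixture patches OpenAILLM.aask_code, the name aask_code on the base class is the mock, so super().aask_code(...) would re-enter the mock and recurse; a copy under a different name never goes through the patched attribute. A standalone demonstration (sync and with hypothetical names, for brevity):

class Provider:                    # stands in for OpenAILLM
    def aask_code(self, messages):
        return {"code": "real"}

class FakeLLM(Provider):           # stands in for MockLLM
    def aask_code(self, messages):
        # If this body said `return super().aask_code(messages)`, the lookup
        # after the patch below would find this very method again: infinite recursion.
        return {"code": "cached"}

fake = FakeLLM()
Provider.aask_code = fake.aask_code  # the patch the fixture applies

assert Provider().aask_code("hi") == {"code": "cached"}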
@@ -78,6 +93,12 @@ class MockLLM(OpenAILLM):
         rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
         return rsp

+    async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
+        messages = self._process_message(messages)
+        msg_key = json.dumps(messages, ensure_ascii=False)
+        rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
+        return rsp
+
     async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
         if msg_key not in self.rsp_cache:
             if not self.allow_open_api_call:
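The mocked aask_code above derives a deterministic cache key by JSON-serializing the processed messages, then defers to _mock_rsp, which centralizes the cache-or-call decision for all three mocked entry points (aask, aask_batch, aask_code). A condensed sketch of that shape; only the two conditions visible in the hunk are taken from the source, the rest is assumed:

async def _mock_rsp_sketch(self, msg_key, ask_func, *args, **kwargs):
    if msg_key not in self.rsp_cache:
        if not self.allow_open_api_call:
            # Assumed behavior: fail fast rather than silently hit the real API.
            raise ValueError(f"no cached response and API calls are disallowed: {msg_key[:80]}")
        rsp = await ask_func(*args, **kwargs)  # fall through to the real provider method
    else:
        rsp = self.rsp_cache[msg_key]          # serve the recorded response
    return rsp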