mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 09:16:21 +02:00
Merge branch 'code_intepreter' into sd_and_debugcode_ut
This commit is contained in:
commit
192e2aa807
13 changed files with 180 additions and 855 deletions
|
|
@ -60,6 +60,7 @@ def llm_mock(rsp_cache, mocker, request):
|
|||
llm.rsp_cache = rsp_cache
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask_batch", llm.aask_batch)
|
||||
mocker.patch("metagpt.provider.openai_api.OpenAILLM.aask_code", llm.aask_code)
|
||||
yield mocker
|
||||
if hasattr(request.node, "test_outcome") and request.node.test_outcome.passed:
|
||||
if llm.rsp_candidates:
|
||||
|
|
@ -67,7 +68,7 @@ def llm_mock(rsp_cache, mocker, request):
|
|||
cand_key = list(rsp_candidate.keys())[0]
|
||||
cand_value = list(rsp_candidate.values())[0]
|
||||
if cand_key not in llm.rsp_cache:
|
||||
logger.info(f"Added '{cand_key[:100]} ... -> {cand_value[:20]} ...' to response cache")
|
||||
logger.info(f"Added '{cand_key[:100]} ... -> {str(cand_value)[:20]} ...' to response cache")
|
||||
llm.rsp_cache.update(rsp_candidate)
|
||||
RSP_CACHE_NEW.update(rsp_candidate)
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -1,4 +1,12 @@
|
|||
from metagpt.actions.write_plan import Plan, Task, precheck_update_plan_from_rsp
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_plan import (
|
||||
Plan,
|
||||
Task,
|
||||
WritePlan,
|
||||
precheck_update_plan_from_rsp,
|
||||
)
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_precheck_update_plan_from_rsp():
|
||||
|
|
@ -12,3 +20,12 @@ def test_precheck_update_plan_from_rsp():
|
|||
invalid_rsp = "wrong"
|
||||
success, _ = precheck_update_plan_from_rsp(invalid_rsp, plan)
|
||||
assert not success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_plan():
|
||||
rsp = await WritePlan().run(context=[Message("run analysis on sklearn iris dataset", role="user")])
|
||||
|
||||
assert "task_id" in rsp
|
||||
assert "instruction" in rsp
|
||||
assert "json" not in rsp # the output should be the content inside ```json ```
|
||||
|
|
|
|||
13
tests/metagpt/roles/test_code_interpreter.py
Normal file
13
tests/metagpt/roles/test_code_interpreter.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.code_interpreter import CodeInterpreter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_interpreter():
|
||||
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
ci = CodeInterpreter(goal=requirement, auto_run=True, use_tools=False)
|
||||
rsp = await ci.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
@ -1,10 +1,16 @@
|
|||
from typing import Optional
|
||||
import json
|
||||
from typing import Optional, Union
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
OriginalLLM = OpenAILLM if not CONFIG.openai_api_type else AzureOpenAILLM
|
||||
|
||||
|
||||
class MockLLM(OpenAILLM):
|
||||
class MockLLM(OriginalLLM):
|
||||
def __init__(self, allow_open_api_call):
|
||||
super().__init__()
|
||||
self.allow_open_api_call = allow_open_api_call
|
||||
|
|
@ -58,6 +64,15 @@ class MockLLM(OpenAILLM):
|
|||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def original_aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
"""
|
||||
A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it will be mocked.
|
||||
Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
|
|
@ -78,6 +93,12 @@ class MockLLM(OpenAILLM):
|
|||
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
|
||||
return rsp
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
messages = self._process_message(messages)
|
||||
msg_key = json.dumps(messages, ensure_ascii=False)
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
|
||||
return rsp
|
||||
|
||||
async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
|
||||
if msg_key not in self.rsp_cache:
|
||||
if not self.allow_open_api_call:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue