mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-18 13:55:17 +02:00
fixbug: unit test
This commit is contained in:
parent
2ed7c50822
commit
e350656725
19 changed files with 207 additions and 146 deletions
|
|
@ -28,10 +28,10 @@ async def test_text_to_embedding(mocker):
|
|||
type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy")
|
||||
|
||||
# Prerequisites
|
||||
assert config.get_openai_llm()
|
||||
assert config.get_openai_llm().api_key
|
||||
assert config.get_openai_llm().proxy
|
||||
|
||||
v = await text_to_embedding(text="Panda emoji")
|
||||
v = await text_to_embedding(text="Panda emoji", config=config)
|
||||
assert len(v.data) > 0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,9 +29,7 @@ async def test_text_to_image(mocker):
|
|||
config = Config.default()
|
||||
assert config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
|
||||
data = await text_to_image(
|
||||
"Panda emoji", size_type="512x512", model_url=config.METAGPT_TEXT_TO_IMAGE_MODEL_URL, config=config
|
||||
)
|
||||
data = await text_to_image("Panda emoji", size_type="512x512", config=config)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
|
||||
|
|
@ -54,6 +52,7 @@ async def test_openai_text_to_image(mocker):
|
|||
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png")
|
||||
|
||||
config = Config.default()
|
||||
config.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None
|
||||
assert config.get_openai_llm()
|
||||
|
||||
data = await text_to_image("Panda emoji", size_type="512x512", config=config)
|
||||
|
|
|
|||
|
|
@ -8,43 +8,64 @@
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.learn.text_to_speech import text_to_speech
|
||||
from metagpt.tools.iflytek_tts import IFlyTekTTS
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_speech():
|
||||
async def test_azure_text_to_speech(mocker):
|
||||
# mock
|
||||
config = Config.default()
|
||||
config.IFLYTEK_API_KEY = None
|
||||
config.IFLYTEK_API_SECRET = None
|
||||
config.IFLYTEK_APP_ID = None
|
||||
mock_result = mocker.Mock()
|
||||
mock_result.audio_data = b"mock audio data"
|
||||
mock_result.reason = ResultReason.SynthesizingAudioCompleted
|
||||
mock_data = mocker.Mock()
|
||||
mock_data.get.return_value = mock_result
|
||||
mocker.patch.object(SpeechSynthesizer, "speak_ssml_async", return_value=mock_data)
|
||||
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/1.wav")
|
||||
|
||||
# Prerequisites
|
||||
assert not config.IFLYTEK_APP_ID
|
||||
assert not config.IFLYTEK_API_KEY
|
||||
assert not config.IFLYTEK_API_SECRET
|
||||
assert config.AZURE_TTS_SUBSCRIPTION_KEY and config.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
|
||||
assert config.AZURE_TTS_REGION
|
||||
|
||||
config.copy()
|
||||
# test azure
|
||||
data = await text_to_speech("panda emoji", config=config)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_iflytek_text_to_speech(mocker):
|
||||
# mock
|
||||
config = Config.default()
|
||||
config.AZURE_TTS_SUBSCRIPTION_KEY = None
|
||||
config.AZURE_TTS_REGION = None
|
||||
mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None)
|
||||
mock_data = mocker.AsyncMock()
|
||||
mock_data.read.return_value = b"mock iflytek"
|
||||
mock_reader = mocker.patch("aiofiles.open")
|
||||
mock_reader.return_value.__aenter__.return_value = mock_data
|
||||
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/1.mp3")
|
||||
|
||||
# Prerequisites
|
||||
assert config.IFLYTEK_APP_ID
|
||||
assert config.IFLYTEK_API_KEY
|
||||
assert config.IFLYTEK_API_SECRET
|
||||
assert config.AZURE_TTS_SUBSCRIPTION_KEY and config.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
|
||||
assert config.AZURE_TTS_REGION
|
||||
assert not config.AZURE_TTS_SUBSCRIPTION_KEY or config.AZURE_TTS_SUBSCRIPTION_KEY == "YOUR_API_KEY"
|
||||
assert not config.AZURE_TTS_REGION
|
||||
|
||||
i = config.copy()
|
||||
# test azure
|
||||
data = await text_to_speech(
|
||||
"panda emoji",
|
||||
subscription_key=i.AZURE_TTS_SUBSCRIPTION_KEY,
|
||||
region=i.AZURE_TTS_REGION,
|
||||
iflytek_api_key=i.IFLYTEK_API_KEY,
|
||||
iflytek_api_secret=i.IFLYTEK_API_SECRET,
|
||||
iflytek_app_id=i.IFLYTEK_APP_ID,
|
||||
)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
# test iflytek
|
||||
## Mock session env
|
||||
i.AZURE_TTS_SUBSCRIPTION_KEY = ""
|
||||
data = await text_to_speech(
|
||||
"panda emoji",
|
||||
subscription_key=i.AZURE_TTS_SUBSCRIPTION_KEY,
|
||||
region=i.AZURE_TTS_REGION,
|
||||
iflytek_api_key=i.IFLYTEK_API_KEY,
|
||||
iflytek_api_secret=i.IFLYTEK_API_SECRET,
|
||||
iflytek_app_id=i.IFLYTEK_APP_ID,
|
||||
)
|
||||
data = await text_to_speech("panda emoji", config=config)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,10 @@ from metagpt.utils.common import any_to_str
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
async def test_run(mocker):
|
||||
# mock
|
||||
mocker.patch("metagpt.learn.text_to_image", return_value="http://mock.com/1.png")
|
||||
|
||||
CONTEXT.kwargs.language = "Chinese"
|
||||
|
||||
class Input(BaseModel):
|
||||
|
|
@ -65,7 +68,7 @@ async def test_run():
|
|||
"cause_by": any_to_str(SkillAction),
|
||||
},
|
||||
]
|
||||
CONTEXT.kwargs.agent_skills = [
|
||||
agent_skills = [
|
||||
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
|
||||
|
|
@ -77,9 +80,11 @@ async def test_run():
|
|||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
CONTEXT.kwargs.language = seed.language
|
||||
CONTEXT.kwargs.agent_description = seed.agent_description
|
||||
role = Assistant(language="Chinese")
|
||||
role.context.kwargs.language = seed.language
|
||||
role.context.kwargs.agent_description = seed.agent_description
|
||||
role.context.kwargs.agent_skills = agent_skills
|
||||
|
||||
role.memory = seed.memory # Restore historical conversation content.
|
||||
while True:
|
||||
has_action = await role.think()
|
||||
|
|
@ -112,6 +117,7 @@ async def test_run():
|
|||
@pytest.mark.asyncio
|
||||
async def test_memory(memory):
|
||||
role = Assistant()
|
||||
role.context.kwargs.agent_skills = []
|
||||
role.load_memory(memory)
|
||||
|
||||
val = role.get_memory()
|
||||
|
|
|
|||
|
|
@ -8,23 +8,25 @@
|
|||
distribution feature for message handling.
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCode, WriteTasks
|
||||
from metagpt.const import (
|
||||
PRDS_FILE_REPO,
|
||||
DEFAULT_WORKSPACE_ROOT,
|
||||
REQUIREMENT_FILENAME,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
)
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.context import CONTEXT, Context
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.schema import CodingContext, Message
|
||||
from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite
|
||||
from metagpt.utils.git_repository import ChangeType
|
||||
from metagpt.utils.git_repository import ChangeType, GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
|
||||
|
||||
|
||||
|
|
@ -32,20 +34,18 @@ from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
|
|||
async def test_engineer():
|
||||
# Prerequisites
|
||||
rqno = "20231221155954.json"
|
||||
await CONTEXT.file_repo.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content)
|
||||
await CONTEXT.file_repo.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content)
|
||||
await CONTEXT.file_repo.save_file(
|
||||
rqno, relative_path=SYSTEM_DESIGN_FILE_REPO, content=MockMessages.system_design.content
|
||||
)
|
||||
await CONTEXT.file_repo.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content)
|
||||
project_repo = ProjectRepo(CONTEXT.git_repo)
|
||||
await project_repo.save(REQUIREMENT_FILENAME, content=MockMessages.req.content)
|
||||
await project_repo.docs.prd.save(rqno, content=MockMessages.prd.content)
|
||||
await project_repo.docs.system_design.save(rqno, content=MockMessages.system_design.content)
|
||||
await project_repo.docs.task.save(rqno, content=MockMessages.json_tasks.content)
|
||||
|
||||
engineer = Engineer()
|
||||
rsp = await engineer.run(Message(content="", cause_by=WriteTasks))
|
||||
|
||||
logger.info(rsp)
|
||||
assert rsp.cause_by == any_to_str(WriteCode)
|
||||
src_file_repo = CONTEXT.git_repo.new_file_repository(CONTEXT.src_workspace)
|
||||
assert src_file_repo.changed_files
|
||||
assert project_repo.with_src_path(CONTEXT.src_workspace).srcs.changed_files
|
||||
|
||||
|
||||
def test_parse_str():
|
||||
|
|
@ -114,48 +114,50 @@ def test_todo():
|
|||
@pytest.mark.asyncio
|
||||
async def test_new_coding_context():
|
||||
# Prerequisites
|
||||
context = Context()
|
||||
context.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
deps = json.loads(await aread(demo_path / "dependencies.json"))
|
||||
dependency = await CONTEXT.git_repo.get_dependency()
|
||||
dependency = await context.git_repo.get_dependency()
|
||||
for k, v in deps.items():
|
||||
await dependency.update(k, set(v))
|
||||
data = await aread(demo_path / "system_design.json")
|
||||
rqno = "20231221155954.json"
|
||||
await awrite(CONTEXT.git_repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data)
|
||||
await awrite(context.git_repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data)
|
||||
data = await aread(demo_path / "tasks.json")
|
||||
await awrite(CONTEXT.git_repo.workdir / TASK_FILE_REPO / rqno, data)
|
||||
await awrite(context.git_repo.workdir / TASK_FILE_REPO / rqno, data)
|
||||
|
||||
CONTEXT.src_workspace = Path(CONTEXT.git_repo.workdir) / "game_2048"
|
||||
src_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=CONTEXT.src_workspace)
|
||||
task_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=TASK_FILE_REPO)
|
||||
design_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
context.src_workspace = Path(context.git_repo.workdir) / "game_2048"
|
||||
|
||||
filename = "game.py"
|
||||
ctx_doc = await Engineer._new_coding_doc(
|
||||
filename=filename,
|
||||
src_file_repo=src_file_repo,
|
||||
task_file_repo=task_file_repo,
|
||||
design_file_repo=design_file_repo,
|
||||
dependency=dependency,
|
||||
)
|
||||
assert ctx_doc
|
||||
assert ctx_doc.filename == filename
|
||||
assert ctx_doc.content
|
||||
ctx = CodingContext.model_validate_json(ctx_doc.content)
|
||||
assert ctx.filename == filename
|
||||
assert ctx.design_doc
|
||||
assert ctx.design_doc.content
|
||||
assert ctx.task_doc
|
||||
assert ctx.task_doc.content
|
||||
assert ctx.code_doc
|
||||
try:
|
||||
filename = "game.py"
|
||||
engineer = Engineer(context=context)
|
||||
ctx_doc = await engineer._new_coding_doc(
|
||||
filename=filename,
|
||||
dependency=dependency,
|
||||
)
|
||||
assert ctx_doc
|
||||
assert ctx_doc.filename == filename
|
||||
assert ctx_doc.content
|
||||
ctx = CodingContext.model_validate_json(ctx_doc.content)
|
||||
assert ctx.filename == filename
|
||||
assert ctx.design_doc
|
||||
assert ctx.design_doc.content
|
||||
assert ctx.task_doc
|
||||
assert ctx.task_doc.content
|
||||
assert ctx.code_doc
|
||||
|
||||
CONTEXT.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED})
|
||||
CONTEXT.git_repo.commit("mock env")
|
||||
await src_file_repo.save(filename=filename, content="content")
|
||||
role = Engineer()
|
||||
assert not role.code_todos
|
||||
await role._new_code_actions()
|
||||
assert role.code_todos
|
||||
context.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED})
|
||||
context.git_repo.commit("mock env")
|
||||
await ProjectRepo(context.git_repo).with_src_path(context.src_workspace).srcs.save(
|
||||
filename=filename, content="content"
|
||||
)
|
||||
role = Engineer(context=context)
|
||||
assert not role.code_todos
|
||||
await role._new_code_actions()
|
||||
assert role.code_todos
|
||||
finally:
|
||||
context.git_repo.delete_repository()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -8,15 +8,14 @@
|
|||
from typing import Dict, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.context import Context
|
||||
from metagpt.roles.teacher import Teacher
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip
|
||||
async def test_init():
|
||||
class Inputs(BaseModel):
|
||||
name: str
|
||||
|
|
@ -30,6 +29,7 @@ async def test_init():
|
|||
expect_goal: str
|
||||
expect_constraints: str
|
||||
expect_desc: str
|
||||
exclude: list = Field(default_factory=list)
|
||||
|
||||
inputs = [
|
||||
{
|
||||
|
|
@ -44,6 +44,7 @@ async def test_init():
|
|||
"kwargs": {},
|
||||
"desc": "aaa{language}",
|
||||
"expect_desc": "aaa{language}",
|
||||
"exclude": ["language", "key1", "something_big", "teaching_language"],
|
||||
},
|
||||
{
|
||||
"name": "Lily{language}",
|
||||
|
|
@ -57,13 +58,21 @@ async def test_init():
|
|||
"kwargs": {"language": "CN", "key1": "HaHa", "something_big": "sleep", "teaching_language": "EN"},
|
||||
"desc": "aaa{language}",
|
||||
"expect_desc": "aaaCN",
|
||||
"language": "CN",
|
||||
"teaching_language": "EN",
|
||||
},
|
||||
]
|
||||
|
||||
for i in inputs:
|
||||
seed = Inputs(**i)
|
||||
context = Context()
|
||||
for k in seed.exclude:
|
||||
context.kwargs.set(k, None)
|
||||
for k, v in seed.kwargs.items():
|
||||
context.kwargs.set(k, v)
|
||||
|
||||
teacher = Teacher(
|
||||
context=context,
|
||||
name=seed.name,
|
||||
profile=seed.profile,
|
||||
goal=seed.goal,
|
||||
|
|
@ -97,8 +106,6 @@ async def test_new_file_name():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
CONTEXT.kwargs.language = "Chinese"
|
||||
CONTEXT.kwargs.teaching_language = "English"
|
||||
lesson = """
|
||||
UNIT 1 Making New Friends
|
||||
TOPIC 1 Welcome to China!
|
||||
|
|
@ -142,7 +149,10 @@ async def test_run():
|
|||
|
||||
3c Match the big letters with the small ones. Then write them on the lines.
|
||||
"""
|
||||
teacher = Teacher()
|
||||
context = Context()
|
||||
context.kwargs.language = "Chinese"
|
||||
context.kwargs.teaching_language = "English"
|
||||
teacher = Teacher(context=context)
|
||||
rsp = await teacher.run(Message(content=lesson))
|
||||
assert rsp
|
||||
|
||||
|
|
|
|||
|
|
@ -7,12 +7,22 @@
|
|||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.tools.iflytek_tts import IFlyTekTTS, oas3_iflytek_tts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tts():
|
||||
async def test_iflytek_tts(mocker):
|
||||
# mock
|
||||
config = Config.default()
|
||||
config.AZURE_TTS_SUBSCRIPTION_KEY = None
|
||||
config.AZURE_TTS_REGION = None
|
||||
mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None)
|
||||
mock_data = mocker.AsyncMock()
|
||||
mock_data.read.return_value = b"mock iflytek"
|
||||
mock_reader = mocker.patch("aiofiles.open")
|
||||
mock_reader.return_value.__aenter__.return_value = mock_data
|
||||
|
||||
# Prerequisites
|
||||
assert config.IFLYTEK_APP_ID
|
||||
assert config.IFLYTEK_API_KEY
|
||||
|
|
|
|||
|
|
@ -27,10 +27,13 @@ async def test_embedding(mocker):
|
|||
type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy")
|
||||
|
||||
# Prerequisites
|
||||
assert config.get_openai_llm()
|
||||
assert config.get_openai_llm().proxy
|
||||
llm_config = config.get_openai_llm()
|
||||
assert llm_config
|
||||
assert llm_config.proxy
|
||||
|
||||
result = await oas3_openai_text_to_embedding("Panda emoji")
|
||||
result = await oas3_openai_text_to_embedding(
|
||||
"Panda emoji", openai_api_key=llm_config.api_key, proxy=llm_config.proxy
|
||||
)
|
||||
assert result
|
||||
assert result.model
|
||||
assert len(result.data) > 0
|
||||
|
|
|
|||
|
|
@ -39,10 +39,10 @@ async def test_draw(mocker):
|
|||
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png")
|
||||
|
||||
# Prerequisites
|
||||
assert config.get_openai_llm()
|
||||
assert config.get_openai_llm().proxy
|
||||
llm_config = config.get_openai_llm()
|
||||
assert llm_config
|
||||
|
||||
binary_data = await oas3_openai_text_to_image("Panda emoji", llm=LLM())
|
||||
binary_data = await oas3_openai_text_to_image("Panda emoji", llm=LLM(llm_config=llm_config))
|
||||
assert binary_data
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue