mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
fixbug: unit test
This commit is contained in:
parent
251352e802
commit
1523a0df81
19 changed files with 152 additions and 82 deletions
|
|
@ -35,7 +35,7 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
|
|||
|
||||
@property
|
||||
def project_repo(self):
|
||||
return ProjectRepo(git_repo=self.context.git_repo)
|
||||
return ProjectRepo(self.context.git_repo)
|
||||
|
||||
@property
|
||||
def prompt_schema(self):
|
||||
|
|
|
|||
|
|
@ -55,10 +55,6 @@ class Context(BaseModel):
|
|||
|
||||
_llm: Optional[BaseLLM] = None
|
||||
|
||||
@property
|
||||
def file_repo(self):
|
||||
return self.git_repo.new_file_repository()
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
"""Return all key-values"""
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ async def text_to_image(text, size_type: str = "512x512", model_url="", config:
|
|||
|
||||
if model_url:
|
||||
binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
|
||||
elif oai_llm := config.get_openai_llm():
|
||||
binary_data = await oas3_openai_text_to_image(text, size_type, LLM(oai_llm))
|
||||
elif config.get_openai_llm():
|
||||
binary_data = await oas3_openai_text_to_image(text, size_type, LLM())
|
||||
else:
|
||||
raise ValueError("Missing necessary parameters.")
|
||||
base64_data = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
|
||||
@property
|
||||
def project_repo(self) -> ProjectRepo:
|
||||
project_repo = ProjectRepo(git_repo=self.context.git_repo)
|
||||
project_repo = ProjectRepo(self.context.git_repo)
|
||||
return project_repo.with_src_path(self.context.src_workspace) if self.context.src_workspace else project_repo
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -78,12 +78,14 @@ class ResourceFileRepositories(FileRepository):
|
|||
|
||||
|
||||
class ProjectRepo(FileRepository):
|
||||
def __init__(self, root: str | Path = None, git_repo: GitRepository = None):
|
||||
if not root and not git_repo:
|
||||
raise ValueError("Invalid root and git_repo")
|
||||
git_repo_ = git_repo or GitRepository(local_path=Path(root))
|
||||
def __init__(self, root: str | Path | GitRepository):
|
||||
if isinstance(root, str) or isinstance(root, Path):
|
||||
git_repo_ = GitRepository(local_path=Path(root))
|
||||
elif isinstance(root, GitRepository):
|
||||
git_repo_ = root
|
||||
else:
|
||||
raise ValueError("Invalid root")
|
||||
super().__init__(git_repo=git_repo_, relative_path=Path("."))
|
||||
|
||||
self._git_repo = git_repo_
|
||||
self.docs = DocFileRepositories(self._git_repo)
|
||||
self.resources = ResourceFileRepositories(self._git_repo)
|
||||
|
|
|
|||
1
tests/data/openai/embedding.json
Normal file
1
tests/data/openai/embedding.json
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -11,9 +11,9 @@ import uuid
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.debug_error import DebugError
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
CODE_CONTENT = '''
|
||||
from typing import List
|
||||
|
|
@ -118,6 +118,7 @@ if __name__ == '__main__':
|
|||
@pytest.mark.asyncio
|
||||
async def test_debug_error():
|
||||
CONTEXT.src_workspace = CONTEXT.git_repo.workdir / uuid.uuid4().hex
|
||||
project_repo = ProjectRepo(CONTEXT.git_repo)
|
||||
ctx = RunCodeContext(
|
||||
code_filename="player.py",
|
||||
test_filename="test_player.py",
|
||||
|
|
@ -125,9 +126,8 @@ async def test_debug_error():
|
|||
output_filename="output.log",
|
||||
)
|
||||
|
||||
repo = CONTEXT.file_repo
|
||||
await repo.save_file(filename=ctx.code_filename, content=CODE_CONTENT, relative_path=CONTEXT.src_workspace)
|
||||
await repo.save_file(filename=ctx.test_filename, content=TEST_CONTENT, relative_path=TEST_CODES_FILE_REPO)
|
||||
await project_repo.with_src_path(CONTEXT.src_workspace).srcs.save(filename=ctx.code_filename, content=CODE_CONTENT)
|
||||
await project_repo.tests.save(filename=ctx.test_filename, content=TEST_CONTENT)
|
||||
output_data = RunCodeResult(
|
||||
stdout=";",
|
||||
stderr="",
|
||||
|
|
@ -141,9 +141,7 @@ async def test_debug_error():
|
|||
"----------------------------------------------------------------------\n"
|
||||
"Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n",
|
||||
)
|
||||
await repo.save_file(
|
||||
filename=ctx.output_filename, content=output_data.model_dump_json(), relative_path=TEST_OUTPUTS_FILE_REPO
|
||||
)
|
||||
await project_repo.test_outputs.save(filename=ctx.output_filename, content=output_data.model_dump_json())
|
||||
debug_error = DebugError(i_context=ctx)
|
||||
|
||||
rsp = await debug_error.run()
|
||||
|
|
|
|||
|
|
@ -9,18 +9,18 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.const import PRDS_FILE_REPO
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api():
|
||||
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE
|
||||
repo = CONTEXT.file_repo
|
||||
project_repo = ProjectRepo(CONTEXT.git_repo)
|
||||
for prd in inputs:
|
||||
await repo.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
|
||||
await project_repo.docs.prd.save(filename="new_prd.txt", content=prd)
|
||||
|
||||
design_api = WriteDesign()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,9 +9,10 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -24,6 +25,6 @@ async def test_prepare_documents():
|
|||
|
||||
await PrepareDocuments(context=CONTEXT).run(with_messages=[msg])
|
||||
assert CONTEXT.git_repo
|
||||
doc = await CONTEXT.file_repo.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
|
||||
doc = await ProjectRepo(CONTEXT.git_repo).docs.get(filename=REQUIREMENT_FILENAME)
|
||||
assert doc
|
||||
assert doc.content == msg.content
|
||||
|
|
|
|||
|
|
@ -9,17 +9,18 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.metagpt.actions.mock_json import DESIGN, PRD
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api():
|
||||
await CONTEXT.file_repo.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
|
||||
await CONTEXT.file_repo.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
project_repo = ProjectRepo(CONTEXT.git_repo)
|
||||
await project_repo.docs.prd.save("1.txt", content=str(PRD))
|
||||
await project_repo.docs.system_design.save("1.txt", content=str(DESIGN))
|
||||
logger.info(CONTEXT.git_repo)
|
||||
|
||||
action = WriteTasks()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from metagpt.context import CONTEXT
|
|||
from metagpt.llm import LLM
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.git_repository import ChangeType
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -22,12 +23,8 @@ async def test_rebuild():
|
|||
# Mock
|
||||
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
|
||||
graph_db_filename = Path(CONTEXT.git_repo.workdir.name).with_suffix(".json")
|
||||
repo = CONTEXT.file_repo
|
||||
await repo.save_file(
|
||||
filename=str(graph_db_filename),
|
||||
relative_path=GRAPH_REPO_FILE_REPO,
|
||||
content=data,
|
||||
)
|
||||
project_repo = ProjectRepo(CONTEXT.git_repo)
|
||||
await project_repo.docs.graph_repo.save(filename=str(graph_db_filename), content=data)
|
||||
CONTEXT.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
|
||||
CONTEXT.git_repo.commit("commit1")
|
||||
|
||||
|
|
@ -35,8 +32,7 @@ async def test_rebuild():
|
|||
name="RedBean", i_context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
|
||||
)
|
||||
await action.run()
|
||||
graph_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
|
||||
assert graph_file_repo.changed_files
|
||||
assert project_repo.docs.graph_repo.changed_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
DESIGN_CONTENT = """
|
||||
{"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are already clear."}
|
||||
|
|
@ -178,17 +178,22 @@ class Snake:
|
|||
@pytest.mark.asyncio
|
||||
async def test_summarize_code():
|
||||
CONTEXT.src_workspace = CONTEXT.git_repo.workdir / "src"
|
||||
await CONTEXT.file_repo.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
|
||||
await CONTEXT.file_repo.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT)
|
||||
await CONTEXT.file_repo.save_file(filename="food.py", relative_path=CONTEXT.src_workspace, content=FOOD_PY)
|
||||
await CONTEXT.file_repo.save_file(filename="game.py", relative_path=CONTEXT.src_workspace, content=GAME_PY)
|
||||
await CONTEXT.file_repo.save_file(filename="main.py", relative_path=CONTEXT.src_workspace, content=MAIN_PY)
|
||||
await CONTEXT.file_repo.save_file(filename="snake.py", relative_path=CONTEXT.src_workspace, content=SNAKE_PY)
|
||||
project_repo = ProjectRepo(CONTEXT.git_repo)
|
||||
await project_repo.docs.system_design.save(filename="1.json", content=DESIGN_CONTENT)
|
||||
await project_repo.docs.task.save(filename="1.json", content=TASK_CONTENT)
|
||||
await project_repo.with_src_path(CONTEXT.src_workspace).srcs.save(filename="food.py", content=FOOD_PY)
|
||||
assert project_repo.srcs.workdir == CONTEXT.src_workspace
|
||||
await project_repo.srcs.save(filename="game.py", content=GAME_PY)
|
||||
await project_repo.srcs.save(filename="main.py", content=MAIN_PY)
|
||||
await project_repo.srcs.save(filename="snake.py", content=SNAKE_PY)
|
||||
|
||||
src_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=CONTEXT.src_workspace)
|
||||
all_files = src_file_repo.all_files
|
||||
all_files = project_repo.srcs.all_files
|
||||
ctx = CodeSummarizeContext(design_filename="1.json", task_filename="1.json", codes_filenames=all_files)
|
||||
action = SummarizeCode(i_context=ctx)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
logger.info(rsp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -12,26 +12,24 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.const import (
|
||||
CODE_SUMMARIES_FILE_REPO,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
)
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code():
|
||||
ccontext = CodingContext(
|
||||
# Prerequisites
|
||||
CONTEXT.src_workspace = CONTEXT.git_repo.workdir / "writecode"
|
||||
|
||||
coding_ctx = CodingContext(
|
||||
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
|
||||
)
|
||||
doc = Document(content=ccontext.model_dump_json())
|
||||
doc = Document(content=coding_ctx.model_dump_json())
|
||||
write_code = WriteCode(i_context=doc)
|
||||
|
||||
code = await write_code.run()
|
||||
|
|
@ -55,33 +53,28 @@ async def test_write_code_deps():
|
|||
# Prerequisites
|
||||
CONTEXT.src_workspace = CONTEXT.git_repo.workdir / "snake1/snake1"
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
await CONTEXT.file_repo.save_file(
|
||||
filename="test_game.py.json",
|
||||
content=await aread(str(demo_path / "test_game.py.json")),
|
||||
relative_path=TEST_OUTPUTS_FILE_REPO,
|
||||
project_repo = ProjectRepo(CONTEXT.git_repo)
|
||||
await project_repo.test_outputs.save(
|
||||
filename="test_game.py.json", content=await aread(str(demo_path / "test_game.py.json"))
|
||||
)
|
||||
await CONTEXT.file_repo.save_file(
|
||||
await project_repo.docs.code_summary.save(
|
||||
filename="20231221155954.json",
|
||||
content=await aread(str(demo_path / "code_summaries.json")),
|
||||
relative_path=CODE_SUMMARIES_FILE_REPO,
|
||||
)
|
||||
await CONTEXT.file_repo.save_file(
|
||||
await project_repo.docs.system_design.save(
|
||||
filename="20231221155954.json",
|
||||
content=await aread(str(demo_path / "system_design.json")),
|
||||
relative_path=SYSTEM_DESIGN_FILE_REPO,
|
||||
)
|
||||
await CONTEXT.file_repo.save_file(
|
||||
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO
|
||||
await project_repo.docs.task.save(
|
||||
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json"))
|
||||
)
|
||||
await CONTEXT.file_repo.save_file(
|
||||
filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONTEXT.src_workspace
|
||||
await project_repo.with_src_path(CONTEXT.src_workspace).srcs.save(
|
||||
filename="main.py", content='if __name__ == "__main__":\nmain()'
|
||||
)
|
||||
ccontext = CodingContext(
|
||||
filename="game.py",
|
||||
design_doc=await CONTEXT.file_repo.get_file(
|
||||
filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO
|
||||
),
|
||||
task_doc=await CONTEXT.file_repo.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO),
|
||||
design_doc=await project_repo.docs.system_design.get(filename="20231221155954.json"),
|
||||
task_doc=await project_repo.docs.task.get(filename="20231221155954.json"),
|
||||
code_doc=Document(filename="game.py", content="", root_path="snake1"),
|
||||
)
|
||||
coding_doc = Document(root_path="snake1", filename="game.py", content=ccontext.json())
|
||||
|
|
|
|||
|
|
@ -9,21 +9,22 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.const import DOCS_FILE_REPO, PRDS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import RoleReactMode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd(new_filename):
|
||||
product_manager = ProductManager()
|
||||
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
|
||||
repo = CONTEXT.file_repo
|
||||
await repo.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)
|
||||
project_repo = ProjectRepo(CONTEXT.git_repo)
|
||||
await project_repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
|
||||
product_manager.rc.react_mode = RoleReactMode.BY_ORDER
|
||||
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
|
||||
assert prd.cause_by == any_to_str(WritePRD)
|
||||
|
|
@ -33,7 +34,7 @@ async def test_write_prd(new_filename):
|
|||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd.content != ""
|
||||
assert CONTEXT.git_repo.new_file_repository(relative_path=PRDS_FILE_REPO).changed_files
|
||||
assert ProjectRepo(product_manager.context.git_repo).docs.prd.changed_files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -6,17 +6,30 @@
|
|||
@File : test_text_to_embedding.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.learn.text_to_embedding import text_to_embedding
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_embedding():
|
||||
async def test_text_to_embedding(mocker):
|
||||
# mock
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.post")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.status = 200
|
||||
data = await aread(Path(__file__).parent / "../../data/openai/embedding.json")
|
||||
mock_response.json.return_value = json.loads(data)
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy")
|
||||
|
||||
# Prerequisites
|
||||
assert config.get_openai_llm()
|
||||
assert config.get_openai_llm().proxy
|
||||
|
||||
v = await text_to_embedding(text="Panda emoji")
|
||||
assert len(v.data) > 0
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@
|
|||
@File : test_text_to_image.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
import base64
|
||||
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.learn.text_to_image import text_to_image
|
||||
|
|
@ -34,7 +36,23 @@ async def test_text_to_image(mocker):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_text_to_image():
|
||||
async def test_openai_text_to_image(mocker):
|
||||
# mocker
|
||||
mock_url = mocker.Mock()
|
||||
mock_url.url.return_value = "http://mock.com/0.png"
|
||||
|
||||
class _MockData(BaseModel):
|
||||
data: list
|
||||
|
||||
mock_data = _MockData(data=[mock_url])
|
||||
mocker.patch.object(openai.resources.images.AsyncImages, "generate", return_value=mock_data)
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.get")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read.return_value = base64.b64encode(b"success")
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png")
|
||||
|
||||
config = Config.default()
|
||||
assert config.get_openai_llm()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,21 +7,31 @@
|
|||
@Modified By: mashenquan, 2023-8-9, add more text formatting options
|
||||
@Modified By: mashenquan, 2023-8-17, move to `tools` folder.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from azure.cognitiveservices.speech import ResultReason
|
||||
from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.tools.azure_tts import AzureTTS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_tts():
|
||||
async def test_azure_tts(mocker):
|
||||
# mock
|
||||
mock_result = mocker.Mock()
|
||||
mock_result.audio_data = b"mock audio data"
|
||||
mock_result.reason = ResultReason.SynthesizingAudioCompleted
|
||||
mock_data = mocker.Mock()
|
||||
mock_data.get.return_value = mock_result
|
||||
mocker.patch.object(SpeechSynthesizer, "speak_ssml_async", return_value=mock_data)
|
||||
mocker.patch.object(Path, "exists", return_value=True)
|
||||
|
||||
# Prerequisites
|
||||
assert config.AZURE_TTS_SUBSCRIPTION_KEY and config.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
|
||||
assert config.AZURE_TTS_REGION
|
||||
|
||||
azure_tts = AzureTTS(subscription_key="", region="")
|
||||
azure_tts = AzureTTS(subscription_key=config.AZURE_TTS_SUBSCRIPTION_KEY, region=config.AZURE_TTS_REGION)
|
||||
text = """
|
||||
女儿看见父亲走了进来,问道:
|
||||
<mstts:express-as role="YoungAdultFemale" style="calm">
|
||||
|
|
|
|||
|
|
@ -5,17 +5,30 @@
|
|||
@Author : mashenquan
|
||||
@File : test_openai_text_to_embedding.py
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding():
|
||||
async def test_embedding(mocker):
|
||||
# mock
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.post")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.status = 200
|
||||
data = await aread(Path(__file__).parent / "../../data/openai/embedding.json")
|
||||
mock_response.json.return_value = json.loads(data)
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy")
|
||||
|
||||
# Prerequisites
|
||||
assert config.get_openai_llm()
|
||||
assert config.get_openai_llm().proxy
|
||||
|
||||
result = await oas3_openai_text_to_embedding("Panda emoji")
|
||||
assert result
|
||||
|
|
|
|||
|
|
@ -5,22 +5,44 @@
|
|||
@Author : mashenquan
|
||||
@File : test_openai_text_to_image.py
|
||||
"""
|
||||
import base64
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.tools.openai_text_to_image import (
|
||||
OpenAIText2Image,
|
||||
oas3_openai_text_to_image,
|
||||
)
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draw():
|
||||
async def test_draw(mocker):
|
||||
# mock
|
||||
mock_url = mocker.Mock()
|
||||
mock_url.url.return_value = "http://mock.com/0.png"
|
||||
|
||||
class _MockData(BaseModel):
|
||||
data: list
|
||||
|
||||
mock_data = _MockData(data=[mock_url])
|
||||
mocker.patch.object(openai.resources.images.AsyncImages, "generate", return_value=mock_data)
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.get")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read.return_value = base64.b64encode(b"success")
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png")
|
||||
|
||||
# Prerequisites
|
||||
assert config.get_openai_llm()
|
||||
assert config.get_openai_llm().proxy
|
||||
|
||||
binary_data = await oas3_openai_text_to_image("Panda emoji")
|
||||
binary_data = await oas3_openai_text_to_image("Panda emoji", llm=LLM())
|
||||
assert binary_data
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue