mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-17 15:35:21 +02:00
Merge branch 'main' into feature-openai-v1
This commit is contained in:
commit
9a4f0d555c
260 changed files with 10576 additions and 3191 deletions
|
|
@ -6,14 +6,18 @@
|
|||
@File : conftest.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
|
||||
import asyncio
|
||||
import re
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
class Context:
|
||||
|
|
@ -68,3 +72,27 @@ def proxy():
|
|||
|
||||
server = asyncio.get_event_loop().run_until_complete(asyncio.start_server(handle_client, "127.0.0.1", 0))
|
||||
return "http://{}:{}".format(*server.sockets[0].getsockname())
|
||||
|
||||
|
||||
# see https://github.com/Delgan/loguru/issues/59#issuecomment-466591978
|
||||
@pytest.fixture
|
||||
def loguru_caplog(caplog):
|
||||
class PropogateHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
logging.getLogger(record.name).handle(record)
|
||||
|
||||
logger.add(PropogateHandler(), format="{message}")
|
||||
yield caplog
|
||||
|
||||
|
||||
# init & dispose git repo
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_and_teardown_git_repo(request):
|
||||
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
|
||||
|
||||
# Destroy git repo at the end of the test session.
|
||||
def fin():
|
||||
CONFIG.git_repo.delete_repository()
|
||||
|
||||
# Register the function for destroying the environment.
|
||||
request.addfinalizer(fin)
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ Python's in-built data structures like lists and dictionaries will be used exten
|
|||
|
||||
For testing, we can use the PyTest framework. This is a mature full-featured Python testing tool that helps you write better programs.
|
||||
|
||||
## Python package name:
|
||||
## Project Name:
|
||||
```python
|
||||
"adventure_game"
|
||||
```
|
||||
|
|
@ -100,7 +100,7 @@ For testing, we can use the PyTest framework. This is a mature full-featured Pyt
|
|||
file_list = ["main.py", "room.py", "player.py", "game.py", "object.py", "puzzle.py", "test_game.py"]
|
||||
```
|
||||
|
||||
## Data structures and interface definitions:
|
||||
## Data structures and interfaces:
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Room{
|
||||
|
|
@ -209,7 +209,7 @@ Shared knowledge for this project includes understanding the basic principles of
|
|||
"""
|
||||
```
|
||||
|
||||
## Anything UNCLEAR: Provide as Plain text. Make clear here. For example, don't forget a main entry. don't forget to init 3rd party libs.
|
||||
## Anything UNCLEAR: Provide as Plain text. Try to clarify it. For example, don't forget a main entry. don't forget to init 3rd party libs.
|
||||
```python
|
||||
"""
|
||||
The original requirements did not specify whether the game should have a save/load feature, multiplayer support, or any specific graphical user interface. More information on these aspects could help in further refining the product design and requirements.
|
||||
|
|
@ -311,12 +311,10 @@ TASKS = [
|
|||
"添加数据API:接受用户输入的文档库,对文档库进行索引\n- 使用MeiliSearch连接并添加文档库",
|
||||
"搜索API:接收用户输入的关键词,返回相关的搜索结果\n- 使用MeiliSearch连接并使用接口获得对应数据",
|
||||
"多条件筛选API:接收用户选择的筛选条件,返回符合条件的搜索结果。\n- 使用MeiliSearch进行筛选并返回符合条件的搜索结果",
|
||||
"智能推荐API:根据用户的搜索历史记录和搜索行为,推荐相关的搜索结果。"
|
||||
"智能推荐API:根据用户的搜索历史记录和搜索行为,推荐相关的搜索结果。",
|
||||
]
|
||||
|
||||
TASKS_2 = [
|
||||
"完成main.py的功能"
|
||||
]
|
||||
TASKS_2 = ["完成main.py的功能"]
|
||||
|
||||
SEARCH_CODE_SAMPLE = """
|
||||
import requests
|
||||
|
|
@ -460,7 +458,7 @@ if __name__ == '__main__':
|
|||
print('No results found.')
|
||||
'''
|
||||
|
||||
MEILI_CODE = '''import meilisearch
|
||||
MEILI_CODE = """import meilisearch
|
||||
from typing import List
|
||||
|
||||
|
||||
|
|
@ -496,9 +494,9 @@ if __name__ == '__main__':
|
|||
|
||||
# 添加文档库到搜索引擎
|
||||
search_engine.add_documents(books_data_source, documents)
|
||||
'''
|
||||
"""
|
||||
|
||||
MEILI_ERROR = '''/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
|
||||
MEILI_ERROR = """/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
|
||||
Traceback (most recent call last):
|
||||
File "/Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py", line 44, in <module>
|
||||
search_engine.add_documents(books_data_source, documents)
|
||||
|
|
@ -506,7 +504,7 @@ Traceback (most recent call last):
|
|||
index = self.client.get_or_create_index(index_name)
|
||||
AttributeError: 'Client' object has no attribute 'get_or_create_index'
|
||||
|
||||
Process finished with exit code 1'''
|
||||
Process finished with exit code 1"""
|
||||
|
||||
MEILI_CODE_REFINED = """
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -7,20 +7,23 @@
|
|||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
from metagpt.actions import ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
t_dict = {"Required Python third-party packages": "\"\"\"\nflask==1.1.2\npygame==2.0.1\n\"\"\"\n",
|
||||
"Required Other language third-party packages": "\"\"\"\nNo third-party packages required for other languages.\n\"\"\"\n",
|
||||
"Full API spec": "\"\"\"\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n '200':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n '200':\n description: A JSON object of the updated game state\n\"\"\"\n",
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."]],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?"}
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
|
||||
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."],
|
||||
],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
|
|
@ -34,17 +37,17 @@ WRITE_TASKS_OUTPUT_MAPPING = {
|
|||
|
||||
|
||||
def test_create_model_class():
|
||||
test_class = ActionOutput.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
t = ActionOutput.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t1 = t(**t_dict)
|
||||
value = t1.dict()["Task list"]
|
||||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
|
|||
|
|
@ -5,17 +5,12 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_azure_tts.py
|
||||
"""
|
||||
from metagpt.actions.azure_tts import AzureTTS
|
||||
from metagpt.tools.azure_tts import AzureTTS
|
||||
|
||||
|
||||
def test_azure_tts():
|
||||
azure_tts = AzureTTS("azure_tts")
|
||||
azure_tts.synthesize_speech(
|
||||
"zh-CN",
|
||||
"zh-CN-YunxiNeural",
|
||||
"Boy",
|
||||
"你好,我是卡卡",
|
||||
"output.wav")
|
||||
azure_tts = AzureTTS()
|
||||
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav")
|
||||
|
||||
# 运行需要先配置 SUBSCRIPTION_KEY
|
||||
# TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import pytest
|
|||
|
||||
from metagpt.actions.clone_function import CloneFunction, run_function_code
|
||||
|
||||
|
||||
source_code = """
|
||||
import pandas as pd
|
||||
import ta
|
||||
|
|
@ -31,14 +30,18 @@ def get_expected_res():
|
|||
import ta
|
||||
|
||||
# 读取股票数据
|
||||
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
|
||||
stock_data = pd.read_csv("./tests/data/baba_stock.csv")
|
||||
stock_data.head()
|
||||
# 计算简单移动平均线
|
||||
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
|
||||
stock_data[['Date', 'Close', 'SMA']].head()
|
||||
stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6)
|
||||
stock_data[["Date", "Close", "SMA"]].head()
|
||||
# 计算布林带
|
||||
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
|
||||
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
|
||||
stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = (
|
||||
ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20),
|
||||
ta.volatility.bollinger_mavg(stock_data["Close"], window=20),
|
||||
ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20),
|
||||
)
|
||||
stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head()
|
||||
return stock_data
|
||||
|
||||
|
||||
|
|
@ -46,9 +49,9 @@ def get_expected_res():
|
|||
async def test_clone_function():
|
||||
clone = CloneFunction()
|
||||
code = await clone.run(template_code, source_code)
|
||||
assert 'def ' in code
|
||||
stock_path = './tests/data/baba_stock.csv'
|
||||
df, msg = run_function_code(code, 'stock_indicator', stock_path)
|
||||
assert "def " in code
|
||||
stock_path = "./tests/data/baba_stock.csv"
|
||||
df, msg = run_function_code(code, "stock_indicator", stock_path)
|
||||
assert not msg
|
||||
expected_df = get_expected_res()
|
||||
assert df.equals(expected_df)
|
||||
|
|
|
|||
|
|
@ -4,17 +4,19 @@
|
|||
@Time : 2023/5/11 17:46
|
||||
@Author : alexanderwu
|
||||
@File : test_debug_error.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.debug_error import DebugError
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
EXAMPLE_MSG_CONTENT = '''
|
||||
---
|
||||
## Development Code File Name
|
||||
player.py
|
||||
## Development Code
|
||||
```python
|
||||
CODE_CONTENT = '''
|
||||
from typing import List
|
||||
from deck import Deck
|
||||
from card import Card
|
||||
|
|
@ -58,12 +60,9 @@ class Player:
|
|||
if self.score > 21 and any(card.rank == 'A' for card in self.hand):
|
||||
self.score -= 10
|
||||
return self.score
|
||||
'''
|
||||
|
||||
```
|
||||
## Test File Name
|
||||
test_player.py
|
||||
## Test Code
|
||||
```python
|
||||
TEST_CONTENT = """
|
||||
import unittest
|
||||
from blackjack_game.player import Player
|
||||
from blackjack_game.deck import Deck
|
||||
|
|
@ -114,42 +113,41 @@ class TestPlayer(unittest.TestCase):
|
|||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
```
|
||||
## Running Command
|
||||
python tests/test_player.py
|
||||
## Running Output
|
||||
standard output: ;
|
||||
standard errors: ..F..
|
||||
======================================================================
|
||||
FAIL: test_player_calculate_score_with_multiple_aces (__main__.TestPlayer)
|
||||
----------------------------------------------------------------------
|
||||
Traceback (most recent call last):
|
||||
File "tests/test_player.py", line 46, in test_player_calculate_score_with_multiple_aces
|
||||
self.assertEqual(player.score, 12)
|
||||
AssertionError: 22 != 12
|
||||
"""
|
||||
|
||||
----------------------------------------------------------------------
|
||||
Ran 5 tests in 0.007s
|
||||
|
||||
FAILED (failures=1)
|
||||
;
|
||||
## instruction:
|
||||
The error is in the development code, specifically in the calculate_score method of the Player class. The method is not correctly handling the case where there are multiple Aces in the player's hand. The current implementation only subtracts 10 from the score once if the score is over 21 and there's an Ace in the hand. However, in the case of multiple Aces, it should subtract 10 for each Ace until the score is 21 or less.
|
||||
## File To Rewrite:
|
||||
player.py
|
||||
## Status:
|
||||
FAIL
|
||||
## Send To:
|
||||
Engineer
|
||||
---
|
||||
'''
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_error():
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / uuid.uuid4().hex
|
||||
ctx = RunCodeContext(
|
||||
code_filename="player.py",
|
||||
test_filename="test_player.py",
|
||||
command=["python", "tests/test_player.py"],
|
||||
output_filename="output.log",
|
||||
)
|
||||
|
||||
debug_error = DebugError("debug_error")
|
||||
await FileRepository.save_file(filename=ctx.code_filename, content=CODE_CONTENT, relative_path=CONFIG.src_workspace)
|
||||
await FileRepository.save_file(filename=ctx.test_filename, content=TEST_CONTENT, relative_path=TEST_CODES_FILE_REPO)
|
||||
output_data = RunCodeResult(
|
||||
stdout=";",
|
||||
stderr="",
|
||||
summary="======================================================================\n"
|
||||
"FAIL: test_player_calculate_score_with_multiple_aces (__main__.TestPlayer)\n"
|
||||
"----------------------------------------------------------------------\n"
|
||||
"Traceback (most recent call last):\n"
|
||||
' File "tests/test_player.py", line 46, in test_player_calculate_score_'
|
||||
"with_multiple_aces\n"
|
||||
" self.assertEqual(player.score, 12)\nAssertionError: 22 != 12\n\n"
|
||||
"----------------------------------------------------------------------\n"
|
||||
"Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n",
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename=ctx.output_filename, content=output_data.json(), relative_path=TEST_OUTPUTS_FILE_REPO
|
||||
)
|
||||
debug_error = DebugError(context=ctx)
|
||||
|
||||
file_name, rewritten_code = await debug_error.run(context=EXAMPLE_MSG_CONTENT)
|
||||
rsp = await debug_error.run()
|
||||
|
||||
assert "class Player" in rewritten_code # rewrite the same class
|
||||
assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12")
|
||||
assert "class Player" in rsp # rewrite the same class
|
||||
# a key logic to rewrite to (original one is "if self.score > 12")
|
||||
assert "while self.score > 21" in rsp
|
||||
|
|
|
|||
|
|
@ -4,33 +4,27 @@
|
|||
@Time : 2023/5/11 19:26
|
||||
@Author : alexanderwu
|
||||
@File : test_design_api.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.const import PRDS_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock import PRD_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api():
|
||||
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
|
||||
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", PRD_SAMPLE]
|
||||
for prd in inputs:
|
||||
await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
|
||||
|
||||
design_api = WriteDesign("design_api")
|
||||
design_api = WriteDesign("design_api")
|
||||
|
||||
result = await design_api.run([Message(content=prd, instruct_content=None)])
|
||||
logger.info(result)
|
||||
result = await design_api.run([Message(content=prd, instruct_content=None)])
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api_calculator():
|
||||
prd = PRD_SAMPLE
|
||||
|
||||
design_api = WriteDesign("design_api")
|
||||
result = await design_api.run([Message(content=prd, instruct_content=None)])
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
assert result
|
||||
|
|
|
|||
|
|
@ -3,21 +3,27 @@
|
|||
"""
|
||||
@Time : 2023/9/13 00:26
|
||||
@Author : fisherdeng
|
||||
@File : test_detail_mining.py
|
||||
@File : test_generate_questions.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.detail_mining import DetailMining
|
||||
from metagpt.actions.generate_questions import GenerateQuestions
|
||||
from metagpt.logs import logger
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detail_mining():
|
||||
topic = "如何做一个生日蛋糕"
|
||||
record = "我认为应该先准备好材料,然后再开始做蛋糕。"
|
||||
detail_mining = DetailMining("detail_mining")
|
||||
rsp = await detail_mining.run(topic=topic, record=record)
|
||||
logger.info(f"{rsp.content=}")
|
||||
|
||||
assert '##OUTPUT' in rsp.content
|
||||
assert '蛋糕' in rsp.content
|
||||
context = """
|
||||
## topic
|
||||
如何做一个生日蛋糕
|
||||
|
||||
## record
|
||||
我认为应该先准备好材料,然后再开始做蛋糕。
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_questions():
|
||||
detail_mining = GenerateQuestions()
|
||||
rsp = await detail_mining.run(context)
|
||||
logger.info(f"{rsp.content=}")
|
||||
|
||||
assert "Questions" in rsp.content
|
||||
assert "1." in rsp.content
|
||||
|
|
|
|||
|
|
@ -8,12 +8,11 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -22,7 +21,7 @@ from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
|
|||
[
|
||||
"../../data/invoices/invoice-3.jpg",
|
||||
"../../data/invoices/invoice-4.zip",
|
||||
]
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr(invoice_path: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
|
|
@ -35,18 +34,8 @@ async def test_invoice_ocr(invoice_path: str):
|
|||
@pytest.mark.parametrize(
|
||||
("invoice_path", "expected_result"),
|
||||
[
|
||||
(
|
||||
"../../data/invoices/invoice-1.pdf",
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
"城市": "深圳市",
|
||||
"总费用/元": "412.00",
|
||||
"开票日期": "2023年02月03日"
|
||||
}
|
||||
]
|
||||
),
|
||||
]
|
||||
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
|
||||
],
|
||||
)
|
||||
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
|
|
@ -59,9 +48,7 @@ async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("invoice_path", "query", "expected_result"),
|
||||
[
|
||||
("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")
|
||||
]
|
||||
[("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
|
||||
)
|
||||
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
|
|
@ -69,4 +56,3 @@ async def test_reply_question(invoice_path: str, query: dict, expected_result: s
|
|||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
|
||||
assert expected_result in result
|
||||
|
||||
|
|
|
|||
30
tests/metagpt/actions/test_prepare_documents.py
Normal file
30
tests/metagpt/actions/test_prepare_documents.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/6
|
||||
@Author : mashenquan
|
||||
@File : test_prepare_documents.py
|
||||
@Desc: Unit test for prepare_documents.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_documents():
|
||||
msg = Message(content="New user requirements balabala...")
|
||||
|
||||
if CONFIG.git_repo:
|
||||
CONFIG.git_repo.delete_repository()
|
||||
CONFIG.git_repo = None
|
||||
|
||||
await PrepareDocuments().run(with_messages=[msg])
|
||||
assert CONFIG.git_repo
|
||||
doc = await FileRepository.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
|
||||
assert doc
|
||||
assert doc.content == msg.content
|
||||
21
tests/metagpt/actions/test_prepare_interview.py
Normal file
21
tests/metagpt/actions/test_prepare_interview.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/13 00:26
|
||||
@Author : fisherdeng
|
||||
@File : test_detail_mining.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.prepare_interview import PrepareInterview
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_interview():
|
||||
action = PrepareInterview()
|
||||
rsp = await action.run("I just graduated and hope to find a job as a Python engineer")
|
||||
logger.info(f"{rsp.content=}")
|
||||
|
||||
assert "Questions" in rsp.content
|
||||
assert "1." in rsp.content
|
||||
|
|
@ -4,10 +4,12 @@
|
|||
@Time : 2023/5/11 17:46
|
||||
@Author : alexanderwu
|
||||
@File : test_run_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.run_code import RunCode
|
||||
from metagpt.schema import RunCodeContext
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -35,37 +37,29 @@ async def test_run_script():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
action = RunCode()
|
||||
result = await action.run(mode="text", code="print('Hello, World')")
|
||||
assert "PASS" in result
|
||||
|
||||
result = await action.run(
|
||||
mode="script",
|
||||
code="echo 'Hello World'",
|
||||
code_file_name="",
|
||||
test_code="",
|
||||
test_file_name="",
|
||||
command=["echo", "Hello World"],
|
||||
working_directory=".",
|
||||
additional_python_paths=[],
|
||||
)
|
||||
assert "PASS" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_failure():
|
||||
action = RunCode()
|
||||
result = await action.run(mode="text", code="result = 1 / 0")
|
||||
assert "FAIL" in result
|
||||
|
||||
result = await action.run(
|
||||
mode="script",
|
||||
code='python -c "print(1/0)"',
|
||||
code_file_name="",
|
||||
test_code="",
|
||||
test_file_name="",
|
||||
command=["python", "-c", "print(1/0)"],
|
||||
working_directory=".",
|
||||
additional_python_paths=[],
|
||||
)
|
||||
assert "FAIL" in result
|
||||
inputs = [
|
||||
(RunCodeContext(mode="text", code_filename="a.txt", code="print('Hello, World')"), "PASS"),
|
||||
(
|
||||
RunCodeContext(
|
||||
mode="script",
|
||||
code_filename="a.sh",
|
||||
code="echo 'Hello World'",
|
||||
command=["echo", "Hello World"],
|
||||
working_directory=".",
|
||||
),
|
||||
"PASS",
|
||||
),
|
||||
(
|
||||
RunCodeContext(
|
||||
mode="script",
|
||||
code_filename="a.py",
|
||||
code='python -c "print(1/0)"',
|
||||
command=["python", "-c", "print(1/0)"],
|
||||
working_directory=".",
|
||||
),
|
||||
"FAIL",
|
||||
),
|
||||
]
|
||||
for ctx, result in inputs:
|
||||
rsp = await RunCode(context=ctx).run()
|
||||
assert result in rsp.summary
|
||||
|
|
|
|||
195
tests/metagpt/actions/test_summarize_code.py
Normal file
195
tests/metagpt/actions/test_summarize_code.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/11 17:46
|
||||
@Author : mashenquan
|
||||
@File : test_summarize_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. Unit test for summarize_code.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
DESIGN_CONTENT = """
|
||||
{"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are already clear."}
|
||||
"""
|
||||
|
||||
TASK_CONTENT = """
|
||||
{"Required Python third-party packages": ["pygame==2.0.1"], "Required Other language third-party packages": ["No third-party packages required for other languages."], "Full API spec": "\n openapi: 3.0.0\n info:\n title: Snake Game API\n version: \"1.0.0\"\n paths:\n /start:\n get:\n summary: Start the game\n responses:\n '200':\n description: Game started successfully\n /pause:\n get:\n summary: Pause the game\n responses:\n '200':\n description: Game paused successfully\n /resume:\n get:\n summary: Resume the game\n responses:\n '200':\n description: Game resumed successfully\n /end:\n get:\n summary: End the game\n responses:\n '200':\n description: Game ended successfully\n /score:\n get:\n summary: Get the current score\n responses:\n '200':\n description: Current score retrieved successfully\n /highscore:\n get:\n summary: Get the high score\n responses:\n '200':\n description: High score retrieved successfully\n components: {}\n ", "Logic Analysis": [["constants.py", "Contains all the constant values like screen size, colors, game speeds, etc. This should be implemented first as it provides the base values for other components."], ["snake.py", "Contains the Snake class with methods for movement, growth, and collision detection. It is dependent on constants.py for configuration values."], ["food.py", "Contains the Food class responsible for spawning food items on the screen. It is dependent on constants.py for configuration values."], ["obstacle.py", "Contains the Obstacle class with methods for spawning, moving, and disappearing of obstacles, as well as collision detection with the snake. It is dependent on constants.py for configuration values."], ["scoreboard.py", "Contains the Scoreboard class for updating, resetting, loading, and saving high scores. It may use constants.py for configuration values and depends on the game's scoring logic."], ["game.py", "Contains the main Game class which includes the game loop and methods for starting, pausing, resuming, and ending the game. It is dependent on snake.py, food.py, obstacle.py, and scoreboard.py."], ["main.py", "The entry point of the game that initializes the game and starts the game loop. It is dependent on game.py."]], "Task list": ["constants.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "game.py", "main.py"], "Shared Knowledge": "\n 'constants.py' should contain all the necessary configurations for the game, such as screen dimensions, color definitions, and speed settings. These constants will be used across multiple files, ensuring consistency and ease of updates. Ensure that the Pygame library is initialized correctly in 'main.py' before starting the game loop. Also, make sure that the game's state is managed properly when pausing and resuming the game.\n ", "Anything UNCLEAR": "The interaction between the 'obstacle.py' and the game loop needs to be clearly defined to ensure obstacles appear and disappear correctly. The lifetime of the obstacle and its random movement should be implemented in a way that does not interfere with the game's performance."}
|
||||
"""
|
||||
|
||||
FOOD_PY = """
|
||||
## food.py
|
||||
import random
|
||||
|
||||
class Food:
|
||||
def __init__(self):
|
||||
self.position = (0, 0)
|
||||
|
||||
def generate(self):
|
||||
x = random.randint(0, 9)
|
||||
y = random.randint(0, 9)
|
||||
self.position = (x, y)
|
||||
|
||||
def get_position(self):
|
||||
return self.position
|
||||
|
||||
"""
|
||||
|
||||
GAME_PY = """
|
||||
## game.py
|
||||
import pygame
|
||||
from snake import Snake
|
||||
from food import Food
|
||||
|
||||
class Game:
|
||||
def __init__(self):
|
||||
self.score = 0
|
||||
self.level = 1
|
||||
self.snake = Snake()
|
||||
self.food = Food()
|
||||
|
||||
def start_game(self):
|
||||
pygame.init()
|
||||
self.initialize_game()
|
||||
self.game_loop()
|
||||
|
||||
def initialize_game(self):
|
||||
self.score = 0
|
||||
self.level = 1
|
||||
self.snake.reset()
|
||||
self.food.generate()
|
||||
|
||||
def game_loop(self):
|
||||
game_over = False
|
||||
|
||||
while not game_over:
|
||||
self.update()
|
||||
self.draw()
|
||||
self.handle_events()
|
||||
self.check_collision()
|
||||
self.increase_score()
|
||||
self.increase_level()
|
||||
|
||||
if self.snake.is_collision():
|
||||
game_over = True
|
||||
self.game_over()
|
||||
|
||||
def update(self):
|
||||
self.snake.move()
|
||||
|
||||
def draw(self):
|
||||
self.snake.draw()
|
||||
self.food.draw()
|
||||
|
||||
def handle_events(self):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
pygame.quit()
|
||||
quit()
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_UP:
|
||||
self.snake.change_direction("UP")
|
||||
elif event.key == pygame.K_DOWN:
|
||||
self.snake.change_direction("DOWN")
|
||||
elif event.key == pygame.K_LEFT:
|
||||
self.snake.change_direction("LEFT")
|
||||
elif event.key == pygame.K_RIGHT:
|
||||
self.snake.change_direction("RIGHT")
|
||||
|
||||
def check_collision(self):
|
||||
if self.snake.get_head() == self.food.get_position():
|
||||
self.snake.grow()
|
||||
self.food.generate()
|
||||
|
||||
def increase_score(self):
|
||||
self.score += 1
|
||||
|
||||
def increase_level(self):
|
||||
if self.score % 10 == 0:
|
||||
self.level += 1
|
||||
|
||||
def game_over(self):
|
||||
print("Game Over")
|
||||
self.initialize_game()
|
||||
|
||||
"""
|
||||
|
||||
MAIN_PY = """
|
||||
## main.py
|
||||
import pygame
|
||||
from game import Game
|
||||
|
||||
def main():
|
||||
pygame.init()
|
||||
game = Game()
|
||||
game.start_game()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
"""
|
||||
|
||||
SNAKE_PY = """
|
||||
## snake.py
|
||||
import pygame
|
||||
|
||||
class Snake:
|
||||
def __init__(self):
|
||||
self.body = [(0, 0)]
|
||||
self.direction = (1, 0)
|
||||
|
||||
def move(self):
|
||||
head = self.body[0]
|
||||
dx, dy = self.direction
|
||||
new_head = (head[0] + dx, head[1] + dy)
|
||||
self.body.insert(0, new_head)
|
||||
self.body.pop()
|
||||
|
||||
def change_direction(self, direction):
|
||||
if direction == "UP":
|
||||
self.direction = (0, -1)
|
||||
elif direction == "DOWN":
|
||||
self.direction = (0, 1)
|
||||
elif direction == "LEFT":
|
||||
self.direction = (-1, 0)
|
||||
elif direction == "RIGHT":
|
||||
self.direction = (1, 0)
|
||||
|
||||
def grow(self):
|
||||
tail = self.body[-1]
|
||||
dx, dy = self.direction
|
||||
new_tail = (tail[0] - dx, tail[1] - dy)
|
||||
self.body.append(new_tail)
|
||||
|
||||
def get_head(self):
|
||||
return self.body[0]
|
||||
|
||||
def get_body(self):
|
||||
return self.body[1:]
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_code():
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / "src"
|
||||
await FileRepository.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
|
||||
await FileRepository.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT)
|
||||
await FileRepository.save_file(filename="food.py", relative_path=CONFIG.src_workspace, content=FOOD_PY)
|
||||
await FileRepository.save_file(filename="game.py", relative_path=CONFIG.src_workspace, content=GAME_PY)
|
||||
await FileRepository.save_file(filename="main.py", relative_path=CONFIG.src_workspace, content=MAIN_PY)
|
||||
await FileRepository.save_file(filename="snake.py", relative_path=CONFIG.src_workspace, content=SNAKE_PY)
|
||||
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
|
||||
all_files = src_file_repo.all_files
|
||||
ctx = CodeSummarizeContext(design_filename="1.json", task_filename="1.json", codes_filenames=all_files)
|
||||
action = SummarizeCode(context=ctx)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
logger.info(rsp)
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/22 02:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
#
|
||||
from tests.metagpt.roles.ui_role import UIDesign
|
||||
|
||||
llm_resp= '''
|
||||
llm_resp = """
|
||||
# UI Design Description
|
||||
```The user interface for the snake game will be designed in a way that is simple, clean, and intuitive. The main elements of the game such as the game grid, snake, food, score, and game over message will be clearly defined and easy to understand. The game grid will be centered on the screen with the score displayed at the top. The game controls will be intuitive and easy to use. The design will be modern and minimalist with a pleasing color scheme.```
|
||||
|
||||
|
|
@ -98,12 +98,13 @@ body {
|
|||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
font-size: 3em;
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def test_ui_design_parse_css():
|
||||
ui_design_work = UIDesign(name="UI design action")
|
||||
|
||||
css = '''
|
||||
css = """
|
||||
body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
|
@ -160,14 +161,14 @@ def test_ui_design_parse_css():
|
|||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
font-size: 3em;
|
||||
'''
|
||||
assert ui_design_work.parse_css_code(context=llm_resp)==css
|
||||
"""
|
||||
assert ui_design_work.parse_css_code(context=llm_resp) == css
|
||||
|
||||
|
||||
def test_ui_design_parse_html():
|
||||
ui_design_work = UIDesign(name="UI design action")
|
||||
|
||||
html = '''
|
||||
html = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
|
|
@ -184,8 +185,5 @@ def test_ui_design_parse_html():
|
|||
<div class="game-over">Game Over</div>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
assert ui_design_work.parse_css_code(context=llm_resp)==html
|
||||
|
||||
|
||||
|
||||
"""
|
||||
assert ui_design_work.parse_css_code(context=llm_resp) == html
|
||||
|
|
|
|||
|
|
@ -4,31 +4,36 @@
|
|||
@Time : 2023/5/11 17:45
|
||||
@Author : alexanderwu
|
||||
@File : test_write_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code():
|
||||
api_design = "设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。"
|
||||
write_code = WriteCode("write_code")
|
||||
context = CodingContext(
|
||||
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
|
||||
)
|
||||
doc = Document(content=context.json())
|
||||
write_code = WriteCode(context=doc)
|
||||
|
||||
code = await write_code.run(api_design)
|
||||
logger.info(code)
|
||||
code = await write_code.run()
|
||||
logger.info(code.json())
|
||||
|
||||
# 我们不能精确地预测生成的代码,但我们可以检查某些关键字
|
||||
assert 'def add' in code
|
||||
assert 'return' in code
|
||||
assert "def add" in code.code_doc.content
|
||||
assert "return" in code.code_doc.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_directly():
|
||||
prompt = WRITE_CODE_PROMPT_SAMPLE + '\n' + TASKS_2[0]
|
||||
prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0]
|
||||
llm = LLM()
|
||||
rsp = await llm.aask(prompt)
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.write_code_review import WriteCodeReview
|
||||
from metagpt.document import Document
|
||||
from metagpt.schema import CodingContext
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -16,13 +18,15 @@ async def test_write_code_review(capfd):
|
|||
def add(a, b):
|
||||
return a +
|
||||
"""
|
||||
# write_code_review = WriteCodeReview("write_code_review")
|
||||
context = CodingContext(
|
||||
filename="math.py", design_doc=Document(content="编写一个从a加b的函数,返回a+b"), code_doc=Document(content=code)
|
||||
)
|
||||
|
||||
code = await WriteCodeReview().run(context="编写一个从a加b的函数,返回a+b", code=code, filename="math.py")
|
||||
context = await WriteCodeReview(context=context).run()
|
||||
|
||||
# 我们不能精确地预测生成的代码评审,但我们可以检查返回的是否为字符串
|
||||
assert isinstance(code, str)
|
||||
assert len(code) > 0
|
||||
assert isinstance(context.code_doc.content, str)
|
||||
assert len(context.code_doc.content) > 0
|
||||
|
||||
captured = capfd.readouterr()
|
||||
print(f"输出内容: {captured.out}")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.write_docstring import WriteDocstring
|
||||
|
||||
code = '''
|
||||
code = """
|
||||
def add_numbers(a: int, b: int):
|
||||
return a + b
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ class Person:
|
|||
|
||||
def greet(self):
|
||||
return f"Hello, my name is {self.name} and I am {self.age} years old."
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -25,7 +25,7 @@ class Person:
|
|||
("numpy", "Parameters"),
|
||||
("sphinx", ":param name:"),
|
||||
],
|
||||
ids=["google", "numpy", "sphinx"]
|
||||
ids=["google", "numpy", "sphinx"],
|
||||
)
|
||||
async def test_write_docstring(style: str, part: str):
|
||||
ret = await WriteDocstring().run(code, style=style)
|
||||
|
|
|
|||
|
|
@ -4,23 +4,29 @@
|
|||
@Time : 2023/5/11 17:45
|
||||
@Author : alexanderwu
|
||||
@File : test_write_prd.py
|
||||
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, PRDS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd():
|
||||
product_manager = ProductManager()
|
||||
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
|
||||
prd = await product_manager.handle(Message(content=requirements, cause_by=BossRequirement))
|
||||
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)
|
||||
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
|
||||
logger.info(requirements)
|
||||
logger.info(prd)
|
||||
|
||||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd != ""
|
||||
assert prd.content != ""
|
||||
assert CONFIG.git_repo.new_file_repository(relative_path=PRDS_FILE_REPO).changed_files
|
||||
|
|
|
|||
53
tests/metagpt/actions/test_write_review.py
Normal file
53
tests/metagpt/actions/test_write_review.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/20 15:01
|
||||
@Author : alexanderwu
|
||||
@File : test_write_review.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_review import WriteReview
|
||||
|
||||
CONTEXT = """
|
||||
{
|
||||
"Language": "zh_cn",
|
||||
"Programming Language": "Python",
|
||||
"Original Requirements": "写一个简单的2048",
|
||||
"Project Name": "game_2048",
|
||||
"Product Goals": [
|
||||
"创建一个引人入胜的用户体验",
|
||||
"确保高性能",
|
||||
"提供可定制的功能"
|
||||
],
|
||||
"User Stories": [
|
||||
"作为用户,我希望能够选择不同的难度级别",
|
||||
"作为玩家,我希望在每局游戏结束后能看到我的得分"
|
||||
],
|
||||
"Competitive Analysis": [
|
||||
"Python Snake Game: 界面简单,缺乏高级功能"
|
||||
],
|
||||
"Competitive Quadrant Chart": "quadrantChart\n title \"Reach and engagement of campaigns\"\n x-axis \"Low Reach\" --> \"High Reach\"\n y-axis \"Low Engagement\" --> \"High Engagement\"\n quadrant-1 \"我们应该扩展\"\n quadrant-2 \"需要推广\"\n quadrant-3 \"重新评估\"\n quadrant-4 \"可能需要改进\"\n \"Campaign A\": [0.3, 0.6]\n \"Campaign B\": [0.45, 0.23]\n \"Campaign C\": [0.57, 0.69]\n \"Campaign D\": [0.78, 0.34]\n \"Campaign E\": [0.40, 0.34]\n \"Campaign F\": [0.35, 0.78]\n \"Our Target Product\": [0.5, 0.6]",
|
||||
"Requirement Analysis": "产品应该用户友好。",
|
||||
"Requirement Pool": [
|
||||
[
|
||||
"P0",
|
||||
"主要代码..."
|
||||
],
|
||||
[
|
||||
"P0",
|
||||
"游戏算法..."
|
||||
]
|
||||
],
|
||||
"UI Design draft": "基本功能描述,简单的风格和布局。",
|
||||
"Anything UNCLEAR": "..."
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_review():
|
||||
write_review = WriteReview()
|
||||
review = await write_review.run(CONTEXT)
|
||||
assert review.instruct_content
|
||||
assert review.get("LGTM") in ["LGTM", "LBTM"]
|
||||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.write_test import WriteTest
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Document, TestingContext
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -24,22 +25,17 @@ async def test_write_test():
|
|||
def generate(self, max_y: int, max_x: int):
|
||||
self.position = (random.randint(1, max_y - 1), random.randint(1, max_x - 1))
|
||||
"""
|
||||
context = TestingContext(filename="food.py", code_doc=Document(filename="food.py", content=code))
|
||||
write_test = WriteTest(context=context)
|
||||
|
||||
write_test = WriteTest()
|
||||
|
||||
test_code = await write_test.run(
|
||||
code_to_test=code,
|
||||
test_file_name="test_food.py",
|
||||
source_file_path="/some/dummy/path/cli_snake_game/cli_snake_game/food.py",
|
||||
workspace="/some/dummy/path/cli_snake_game",
|
||||
)
|
||||
logger.info(test_code)
|
||||
context = await write_test.run()
|
||||
logger.info(context.json())
|
||||
|
||||
# We cannot exactly predict the generated test cases, but we can check if it is a string and if it is not empty
|
||||
assert isinstance(test_code, str)
|
||||
assert "from cli_snake_game.food import Food" in test_code
|
||||
assert "class TestFood(unittest.TestCase)" in test_code
|
||||
assert "def test_generate" in test_code
|
||||
assert isinstance(context.test_doc.content, str)
|
||||
assert "from food import Food" in context.test_doc.content
|
||||
assert "class TestFood(unittest.TestCase)" in context.test_doc.content
|
||||
assert "def test_generate" in context.test_doc.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -9,14 +9,11 @@ from typing import Dict
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_tutorial import WriteDirectory, WriteContent
|
||||
from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("language", "topic"),
|
||||
[("English", "Write a tutorial about Python")]
|
||||
)
|
||||
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
|
||||
async def test_write_directory(language: str, topic: str):
|
||||
ret = await WriteDirectory(language=language).run(topic=topic)
|
||||
assert isinstance(ret, dict)
|
||||
|
|
@ -30,7 +27,7 @@ async def test_write_directory(language: str, topic: str):
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("language", "topic", "directory"),
|
||||
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})]
|
||||
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
|
||||
)
|
||||
async def test_write_content(language: str, topic: str, directory: Dict):
|
||||
ret = await WriteContent(language=language, directory=directory).run(topic=topic)
|
||||
|
|
|
|||
|
|
@ -12,12 +12,12 @@ from metagpt.document_store.chromadb_store import ChromaStore
|
|||
def test_chroma_store():
|
||||
"""FIXME:chroma使用感觉很诡异,一用Python就挂,测试用例里也是"""
|
||||
# 创建 ChromaStore 实例,使用 'sample_collection' 集合
|
||||
document_store = ChromaStore('sample_collection_1')
|
||||
document_store = ChromaStore("sample_collection_1")
|
||||
|
||||
# 使用 write 方法添加多个文档
|
||||
document_store.write(["This is document1", "This is document2"],
|
||||
[{"source": "google-docs"}, {"source": "notion"}],
|
||||
["doc1", "doc2"])
|
||||
document_store.write(
|
||||
["This is document1", "This is document2"], [{"source": "google-docs"}, {"source": "notion"}], ["doc1", "doc2"]
|
||||
)
|
||||
|
||||
# 使用 add 方法添加一个文档
|
||||
document_store.add("This is document3", {"source": "notion"}, "doc3")
|
||||
|
|
|
|||
|
|
@ -7,22 +7,22 @@
|
|||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.document_store.document import Document
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.document import IndexableDocument
|
||||
|
||||
CASES = [
|
||||
("st/faq.xlsx", "Question", "Answer", 1),
|
||||
("cases/faq.csv", "Question", "Answer", 1),
|
||||
("requirements.txt", None, None, 0),
|
||||
# ("cases/faq.csv", "Question", "Answer", 1),
|
||||
# ("cases/faq.json", "Question", "Answer", 1),
|
||||
("docx/faq.docx", None, None, 1),
|
||||
("cases/faq.pdf", None, None, 0), # 这是因为pdf默认没有分割段落
|
||||
("cases/faq.txt", None, None, 0), # 这是因为txt按照256分割段落
|
||||
# ("docx/faq.docx", None, None, 1),
|
||||
# ("cases/faq.pdf", None, None, 0), # 这是因为pdf默认没有分割段落
|
||||
# ("cases/faq.txt", None, None, 0), # 这是因为txt按照256分割段落
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("relative_path, content_col, meta_col, threshold", CASES)
|
||||
def test_document(relative_path, content_col, meta_col, threshold):
|
||||
doc = Document(DATA_PATH / relative_path, content_col, meta_col)
|
||||
doc = IndexableDocument.from_path(METAGPT_ROOT / relative_path, content_col, meta_col)
|
||||
rsp = doc.get_docs_and_metadatas()
|
||||
assert len(rsp[0]) > threshold
|
||||
assert len(rsp[1]) > threshold
|
||||
|
|
|
|||
|
|
@ -39,11 +39,11 @@ user: 没有了
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_faiss_store_search():
|
||||
store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
|
||||
store.add(['油皮洗面奶'])
|
||||
store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
|
||||
store.add(["油皮洗面奶"])
|
||||
role = Sales(store=store)
|
||||
|
||||
queries = ['油皮洗面奶', '介绍下欧莱雅的']
|
||||
queries = ["油皮洗面奶", "介绍下欧莱雅的"]
|
||||
for query in queries:
|
||||
rsp = await role.run(query)
|
||||
assert rsp
|
||||
|
|
@ -60,7 +60,10 @@ def customer_service():
|
|||
async def test_faiss_store_customer_service():
|
||||
allq = [
|
||||
# ["我的餐怎么两小时都没到", "退货吧"],
|
||||
["你好收不到取餐码,麻烦帮我开箱", "14750187158", ]
|
||||
[
|
||||
"你好收不到取餐码,麻烦帮我开箱",
|
||||
"14750187158",
|
||||
]
|
||||
]
|
||||
role = customer_service()
|
||||
for queries in allq:
|
||||
|
|
@ -71,4 +74,4 @@ async def test_faiss_store_customer_service():
|
|||
|
||||
def test_faiss_store_no_file():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
FaissStore(DATA_PATH / 'wtf.json')
|
||||
FaissStore(DATA_PATH / "wtf.json")
|
||||
|
|
|
|||
|
|
@ -5,27 +5,33 @@
|
|||
@Author : unkn-wn (Leon Yee)
|
||||
@File : test_lancedb_store.py
|
||||
"""
|
||||
from metagpt.document_store.lancedb_store import LanceStore
|
||||
import pytest
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.document_store.lancedb_store import LanceStore
|
||||
|
||||
|
||||
@pytest
|
||||
def test_lance_store():
|
||||
|
||||
# This simply establishes the connection to the database, so we can drop the table if it exists
|
||||
store = LanceStore('test')
|
||||
store = LanceStore("test")
|
||||
|
||||
store.drop('test')
|
||||
store.drop("test")
|
||||
|
||||
store.write(data=[[random.random() for _ in range(100)] for _ in range(2)],
|
||||
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
|
||||
ids=["doc1", "doc2"])
|
||||
store.write(
|
||||
data=[[random.random() for _ in range(100)] for _ in range(2)],
|
||||
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
|
||||
ids=["doc1", "doc2"],
|
||||
)
|
||||
|
||||
store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3")
|
||||
|
||||
result = store.search([random.random() for _ in range(100)], n_results=3)
|
||||
assert(len(result) == 3)
|
||||
assert len(result) == 3
|
||||
|
||||
store.delete("doc2")
|
||||
result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine')
|
||||
assert(len(result) == 1)
|
||||
result = store.search(
|
||||
[random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric="cosine"
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import numpy as np
|
|||
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
|
||||
from metagpt.logs import logger
|
||||
|
||||
book_columns = {'idx': int, 'name': str, 'desc': str, 'emb': np.ndarray, 'price': float}
|
||||
book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float}
|
||||
book_data = [
|
||||
[i for i in range(10)],
|
||||
[f"book-{i}" for i in range(10)],
|
||||
|
|
@ -25,12 +25,12 @@ book_data = [
|
|||
def test_milvus_store():
|
||||
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
|
||||
milvus_store = MilvusStore(milvus_connection)
|
||||
milvus_store.drop('Book')
|
||||
milvus_store.create_collection('Book', book_columns)
|
||||
milvus_store.drop("Book")
|
||||
milvus_store.create_collection("Book", book_columns)
|
||||
milvus_store.add(book_data)
|
||||
milvus_store.build_index('emb')
|
||||
milvus_store.build_index("emb")
|
||||
milvus_store.load_collection()
|
||||
|
||||
results = milvus_store.search([[1.0, 1.0]], field='emb')
|
||||
results = milvus_store.search([[1.0, 1.0]], field="emb")
|
||||
logger.info(results)
|
||||
assert results
|
||||
|
|
|
|||
|
|
@ -24,9 +24,7 @@ random.seed(seed_value)
|
|||
vectors = [[random.random() for _ in range(2)] for _ in range(10)]
|
||||
|
||||
points = [
|
||||
PointStruct(
|
||||
id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
|
||||
)
|
||||
PointStruct(id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10})
|
||||
for idx, vector in enumerate(vectors)
|
||||
]
|
||||
|
||||
|
|
@ -57,9 +55,7 @@ def test_milvus_store():
|
|||
results = qdrant_store.search(
|
||||
"Book",
|
||||
query=[1.0, 1.0],
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
|
||||
),
|
||||
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
|
||||
)
|
||||
assert results[0]["id"] == 8
|
||||
assert results[0]["score"] == 0.9100373450784073
|
||||
|
|
@ -68,9 +64,7 @@ def test_milvus_store():
|
|||
results = qdrant_store.search(
|
||||
"Book",
|
||||
query=[1.0, 1.0],
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
|
||||
),
|
||||
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
|
||||
return_vector=True,
|
||||
)
|
||||
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def test_skill_manager():
|
|||
|
||||
rsp = manager.retrieve_skill("写测试用例")
|
||||
logger.info(rsp)
|
||||
assert rsp[0] == 'WriteTest'
|
||||
assert rsp[0] == "WriteTest"
|
||||
|
||||
rsp = manager.retrieve_skill_scored("写PRD")
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
"""
|
||||
@Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
"""
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.memory import LongTermMemory
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_ltm_search():
|
||||
|
|
@ -14,25 +16,25 @@ def test_ltm_search():
|
|||
openai_api_key = CONFIG.openai_api_key
|
||||
assert len(openai_api_key) > 20
|
||||
|
||||
role_id = 'UTUserLtm(Product Manager)'
|
||||
rc = RoleContext(watch=[BossRequirement])
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
rc = RoleContext(watch=[UserRequirement])
|
||||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
idea = 'Write a cli snake game'
|
||||
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
|
||||
idea = "Write a cli snake game"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = 'Write a game of cli snake'
|
||||
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = 'Write a 2048 web game'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
ltm.add(new_message)
|
||||
|
|
@ -47,8 +49,8 @@ def test_ltm_search():
|
|||
news = ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = 'Write a Battle City'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_idea = "Write a Battle City"
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,22 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
"""
|
||||
@Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
"""
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
idea = 'Write a cli snake game'
|
||||
role_id = 'UTUser1(Product Manager)'
|
||||
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
|
||||
idea = "Write a cli snake game"
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
|
|
@ -23,13 +25,13 @@ def test_idea_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = 'Write a game of cli snake'
|
||||
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = 'Write a 2048 web game'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
|
|
@ -38,22 +40,15 @@ def test_idea_message():
|
|||
|
||||
|
||||
def test_actionout_message():
|
||||
out_mapping = {
|
||||
'field1': (str, ...),
|
||||
'field2': (List[str], ...)
|
||||
}
|
||||
out_data = {
|
||||
'field1': 'field1 value',
|
||||
'field2': ['field2 value1', 'field2 value2']
|
||||
}
|
||||
ic_obj = ActionOutput.create_model_class('prd', out_mapping)
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
role_id = 'UTUser2(Architect)'
|
||||
content = 'The boss has requested the creation of a command-line interface (CLI) snake game'
|
||||
message = Message(content=content,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD) # WritePRD as test action
|
||||
role_id = "UTUser2(Architect)"
|
||||
content = "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
message = Message(
|
||||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
|
|
@ -62,19 +57,13 @@ def test_actionout_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = 'The request is command-line interface (CLI) snake game'
|
||||
sim_message = Message(content=sim_conent,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD)
|
||||
sim_conent = "The request is command-line interface (CLI) snake game"
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = 'Incorporate basic features of a snake game such as scoring and increasing difficulty'
|
||||
new_message = Message(content=new_conent,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD)
|
||||
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
@Time : 2023/9/16 20:03
|
||||
@Author : femto Zheng
|
||||
@File : test_basic_planner.py
|
||||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import pytest
|
||||
from semantic_kernel.core_skills import FileIOSkill, MathSkill, TextSkill, TimeSkill
|
||||
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
|
||||
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.roles.sk_agent import SkAgent
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -23,7 +25,8 @@ async def test_action_planner():
|
|||
role.import_skill(TimeSkill(), "time")
|
||||
role.import_skill(TextSkill(), "text")
|
||||
task = "What is the sum of 110 and 990?"
|
||||
role.recv(Message(content=task, cause_by=BossRequirement))
|
||||
|
||||
role.put_message(Message(content=task, cause_by=UserRequirement))
|
||||
await role._observe()
|
||||
await role._think() # it will choose mathskill.Add
|
||||
assert "1100" == (await role._act()).content
|
||||
|
|
|
|||
|
|
@ -4,11 +4,13 @@
|
|||
@Time : 2023/9/16 20:03
|
||||
@Author : femto Zheng
|
||||
@File : test_basic_planner.py
|
||||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import pytest
|
||||
from semantic_kernel.core_skills import TextSkill
|
||||
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.const import SKILL_DIRECTORY
|
||||
from metagpt.roles.sk_agent import SkAgent
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -26,7 +28,8 @@ async def test_basic_planner():
|
|||
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "WriterSkill")
|
||||
role.import_skill(TextSkill(), "TextSkill")
|
||||
# using BasicPlanner
|
||||
role.recv(Message(content=task, cause_by=BossRequirement))
|
||||
role.put_message(Message(content=task, cause_by=UserRequirement))
|
||||
await role._observe()
|
||||
await role._think()
|
||||
# assuming sk_agent will think he needs WriterSkill.Brainstorm and WriterSkill.Translate
|
||||
assert "WriterSkill.Brainstorm" in role.plan.generated_plan.result
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from metagpt.schema import Message
|
|||
|
||||
|
||||
def test_message():
|
||||
message = Message(role='user', content='wtf')
|
||||
assert 'role' in message.to_dict()
|
||||
assert 'user' in str(message)
|
||||
message = Message(role="user", content="wtf")
|
||||
assert "role" in message.to_dict()
|
||||
assert "user" in str(message)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,6 @@ def test_message():
|
|||
llm = SparkAPI()
|
||||
|
||||
logger.info(llm.ask('只回答"收到了"这三个字。'))
|
||||
result = llm.ask('写一篇五百字的日记')
|
||||
result = llm.ask("写一篇五百字的日记")
|
||||
logger.info(result)
|
||||
assert len(result) > 100
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@
|
|||
@Author : alexanderwu
|
||||
@File : mock.py
|
||||
"""
|
||||
from metagpt.actions import BossRequirement, WriteDesign, WritePRD, WriteTasks
|
||||
from metagpt.actions import UserRequirement, WriteDesign, WritePRD, WriteTasks
|
||||
from metagpt.schema import Message
|
||||
|
||||
BOSS_REQUIREMENT = """开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"""
|
||||
USER_REQUIREMENT = """开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"""
|
||||
|
||||
DETAIL_REQUIREMENT = """需求:开发一个基于LLM(大语言模型)与私有知识库的搜索引擎,希望有几点能力
|
||||
1. 用户可以在私有知识库进行搜索,再根据大语言模型进行总结,输出的结果包括了总结
|
||||
|
|
@ -71,7 +71,7 @@ PRD = '''## 原始需求
|
|||
```
|
||||
'''
|
||||
|
||||
SYSTEM_DESIGN = '''## Python package name
|
||||
SYSTEM_DESIGN = """## Project name
|
||||
```python
|
||||
"smart_search_engine"
|
||||
```
|
||||
|
|
@ -94,7 +94,7 @@ SYSTEM_DESIGN = '''## Python package name
|
|||
]
|
||||
```
|
||||
|
||||
## Data structures and interface definitions
|
||||
## Data structures and interfaces
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Main {
|
||||
|
|
@ -149,10 +149,10 @@ sequenceDiagram
|
|||
S-->>SE: return summary
|
||||
SE-->>M: return summary
|
||||
```
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
TASKS = '''## Logic Analysis
|
||||
TASKS = """## Logic Analysis
|
||||
|
||||
在这个项目中,所有的模块都依赖于“SearchEngine”类,这是主入口,其他的模块(Index、Ranking和Summary)都通过它交互。另外,"Index"类又依赖于"KnowledgeBase"类,因为它需要从知识库中获取数据。
|
||||
|
||||
|
|
@ -181,7 +181,7 @@ task_list = [
|
|||
]
|
||||
```
|
||||
这个任务列表首先定义了最基础的模块,然后是依赖这些模块的模块,最后是辅助模块。可以根据团队的能力和资源,同时开发多个任务,只要满足依赖关系。例如,在开发"search.py"之前,可以同时开发"knowledge_base.py"、"index.py"、"ranking.py"和"summary.py"。
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
TASKS_TOMATO_CLOCK = '''## Required Python third-party packages: Provided in requirements.txt format
|
||||
|
|
@ -224,35 +224,35 @@ task_list = [
|
|||
TASK = """smart_search_engine/knowledge_base.py"""
|
||||
|
||||
STRS_FOR_PARSING = [
|
||||
"""
|
||||
"""
|
||||
## 1
|
||||
```python
|
||||
a
|
||||
```
|
||||
""",
|
||||
"""
|
||||
"""
|
||||
##2
|
||||
```python
|
||||
"a"
|
||||
```
|
||||
""",
|
||||
"""
|
||||
"""
|
||||
## 3
|
||||
```python
|
||||
a = "a"
|
||||
```
|
||||
""",
|
||||
"""
|
||||
"""
|
||||
## 4
|
||||
```python
|
||||
a = 'a'
|
||||
```
|
||||
"""
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
class MockMessages:
|
||||
req = Message(role="Boss", content=BOSS_REQUIREMENT, cause_by=BossRequirement)
|
||||
req = Message(role="User", content=USER_REQUIREMENT, cause_by=UserRequirement)
|
||||
prd = Message(role="Product Manager", content=PRD, cause_by=WritePRD)
|
||||
system_design = Message(role="Architect", content=SYSTEM_DESIGN, cause_by=WriteDesign)
|
||||
tasks = Message(role="Project Manager", content=TASKS, cause_by=WriteTasks)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
@Time : 2023/5/20 14:37
|
||||
@Author : alexanderwu
|
||||
@File : test_architect.py
|
||||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
|
|
@ -15,7 +17,7 @@ from tests.metagpt.roles.mock import MockMessages
|
|||
@pytest.mark.asyncio
|
||||
async def test_architect():
|
||||
role = Architect()
|
||||
role.recv(MockMessages.req)
|
||||
rsp = await role.handle(MockMessages.prd)
|
||||
role.put_message(MockMessages.req)
|
||||
rsp = await role.run(MockMessages.prd)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
@Time : 2023/5/12 10:14
|
||||
@Author : alexanderwu
|
||||
@File : test_engineer.py
|
||||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
|
|
@ -22,10 +24,10 @@ from tests.metagpt.roles.mock import (
|
|||
async def test_engineer():
|
||||
engineer = Engineer()
|
||||
|
||||
engineer.recv(MockMessages.req)
|
||||
engineer.recv(MockMessages.prd)
|
||||
engineer.recv(MockMessages.system_design)
|
||||
rsp = await engineer.handle(MockMessages.tasks)
|
||||
engineer.put_message(MockMessages.req)
|
||||
engineer.put_message(MockMessages.prd)
|
||||
engineer.put_message(MockMessages.system_design)
|
||||
rsp = await engineer.run(MockMessages.tasks)
|
||||
|
||||
logger.info(rsp)
|
||||
assert "all done." == rsp.content
|
||||
|
|
@ -35,13 +37,13 @@ def test_parse_str():
|
|||
for idx, i in enumerate(STRS_FOR_PARSING):
|
||||
text = CodeParser.parse_str(f"{idx+1}", i)
|
||||
# logger.info(text)
|
||||
assert text == 'a'
|
||||
assert text == "a"
|
||||
|
||||
|
||||
def test_parse_blocks():
|
||||
tasks = CodeParser.parse_blocks(TASKS)
|
||||
logger.info(tasks.keys())
|
||||
assert 'Task list' in tasks.keys()
|
||||
assert "Task list" in tasks.keys()
|
||||
|
||||
|
||||
target_list = [
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -24,82 +24,39 @@ from metagpt.schema import Message
|
|||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-1.pdf"),
|
||||
Path("../../../data/invoice_table/invoice-1.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
"城市": "深圳市",
|
||||
"总费用/元": 412.00,
|
||||
"开票日期": "2023年02月03日"
|
||||
}
|
||||
]
|
||||
[{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-2.png"),
|
||||
Path("../../../data/invoice_table/invoice-2.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "铁头",
|
||||
"城市": "广州市",
|
||||
"总费用/元": 898.00,
|
||||
"开票日期": "2023年03月17日"
|
||||
}
|
||||
]
|
||||
[{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-3.jpg"),
|
||||
Path("../../../data/invoice_table/invoice-3.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "夏天",
|
||||
"城市": "福州市",
|
||||
"总费用/元": 2462.00,
|
||||
"开票日期": "2023年08月26日"
|
||||
}
|
||||
]
|
||||
[{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-4.zip"),
|
||||
Path("../../../data/invoice_table/invoice-4.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
"城市": "深圳市",
|
||||
"总费用/元": 412.00,
|
||||
"开票日期": "2023年02月03日"
|
||||
},
|
||||
{
|
||||
"收款人": "铁头",
|
||||
"城市": "广州市",
|
||||
"总费用/元": 898.00,
|
||||
"开票日期": "2023年03月17日"
|
||||
},
|
||||
{
|
||||
"收款人": "夏天",
|
||||
"城市": "福州市",
|
||||
"总费用/元": 2462.00,
|
||||
"开票日期": "2023年08月26日"
|
||||
}
|
||||
]
|
||||
{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
|
||||
{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
|
||||
{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
|
||||
],
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr_assistant(
|
||||
query: str,
|
||||
invoice_path: Path,
|
||||
invoice_table_path: Path,
|
||||
expected_result: list[dict]
|
||||
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
|
||||
):
|
||||
invoice_path = Path.cwd() / invoice_path
|
||||
role = InvoiceOCRAssistant()
|
||||
await role.run(Message(
|
||||
content=query,
|
||||
instruct_content={"file_path": invoice_path}
|
||||
))
|
||||
await role.run(Message(content=query, instruct_content={"file_path": invoice_path}))
|
||||
invoice_table_path = Path.cwd() / invoice_table_path
|
||||
df = pd.read_excel(invoice_table_path)
|
||||
dict_result = df.to_dict(orient='records')
|
||||
dict_result = df.to_dict(orient="records")
|
||||
assert dict_result == expected_result
|
||||
|
||||
|
|
|
|||
|
|
@ -11,10 +11,12 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
|
|||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["dataiku", "datarobot"]'
|
||||
elif "Provide up to 4 queries related to your research topic" in prompt:
|
||||
return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \
|
||||
return (
|
||||
'["Dataiku machine learning platform", "DataRobot AI platform comparison", '
|
||||
'"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]'
|
||||
)
|
||||
elif "sort the remaining search results" in prompt:
|
||||
return '[1,2]'
|
||||
return "[1,2]"
|
||||
elif "Not relevant." in prompt:
|
||||
return "Not relevant" if random() > 0.5 else prompt[-100:]
|
||||
elif "provide a detailed research report" in prompt:
|
||||
|
|
|
|||
11
tests/metagpt/roles/test_role.py
Normal file
11
tests/metagpt/roles/test_role.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of Role
|
||||
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
def test_role_desc():
|
||||
role = Role(profile="Sales", desc="Best Seller")
|
||||
assert role.profile == "Sales"
|
||||
assert role._setting.desc == "Best Seller"
|
||||
|
|
@ -12,10 +12,7 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("language", "topic"),
|
||||
[("Chinese", "Write a tutorial about Python")]
|
||||
)
|
||||
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about Python")])
|
||||
async def test_tutorial_assistant(language: str, topic: str):
|
||||
topic = "Write a tutorial about MySQL"
|
||||
role = TutorialAssistant(language=language)
|
||||
|
|
@ -24,4 +21,4 @@ async def test_tutorial_assistant(language: str, topic: str):
|
|||
title = filename.split("/")[-1].split(".")[0]
|
||||
async with aiofiles.open(filename, mode="r") as reader:
|
||||
content = await reader.read()
|
||||
assert content.startswith(f"# {title}")
|
||||
assert content.startswith(f"# {title}")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/22 02:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
#
|
||||
from metagpt.team import Team
|
||||
from metagpt.roles import ProductManager
|
||||
|
||||
from metagpt.team import Team
|
||||
from tests.metagpt.roles.ui_role import UI
|
||||
|
||||
|
||||
|
|
@ -18,5 +17,5 @@ async def test_ui_role(idea: str, investment: float = 3.0, n_round: int = 5):
|
|||
company = Team()
|
||||
company.hire([ProductManager(), UI()])
|
||||
company.invest(investment)
|
||||
company.start_project(idea)
|
||||
company.run_project(idea)
|
||||
await company.run(n_round=n_round)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/15 16:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
# @Desc :
|
||||
import os
|
||||
import re
|
||||
|
|
@ -8,51 +8,48 @@ from functools import wraps
|
|||
from importlib import import_module
|
||||
|
||||
from metagpt.actions import Action, ActionOutput, WritePRD
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
|
||||
# from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools.sd_engine import SDEngine
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
# Context
|
||||
{context}
|
||||
|
||||
## Format example
|
||||
{format_example}
|
||||
-----
|
||||
Role: You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style.
|
||||
Requirements: Based on the context, fill in the following missing information, provide detailed HTML and CSS code
|
||||
Attention: Use '##' to split sections, not '#', and '## <SECTION_NAME>' SHOULD WRITE BEFORE the code and triple quote.
|
||||
|
||||
## UI Design Description:Provide as Plain text, place the design objective here
|
||||
## Selected Elements:Provide as Plain text, up to 5 specified elements, clear and simple
|
||||
## HTML Layout:Provide as Plain text, use standard HTML code
|
||||
## CSS Styles (styles.css):Provide as Plain text,use standard css code
|
||||
## Anything UNCLEAR:Provide as Plain text. Make clear here.
|
||||
|
||||
## Role
|
||||
You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style.
|
||||
"""
|
||||
|
||||
FORMAT_EXAMPLE = """
|
||||
UI_DESIGN_DESC = ActionNode(
|
||||
key="UI Design Desc",
|
||||
expected_type=str,
|
||||
instruction="place the design objective here",
|
||||
example="Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements"
|
||||
" commonly found in snake games",
|
||||
)
|
||||
|
||||
## UI Design Description
|
||||
```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ```
|
||||
SELECTED_ELEMENTS = ActionNode(
|
||||
key="Selected Elements",
|
||||
expected_type=list[str],
|
||||
instruction="up to 5 specified elements, clear and simple",
|
||||
example=[
|
||||
"Game Grid: The game grid is a rectangular...",
|
||||
"Snake: The player controls a snake that moves across the grid...",
|
||||
"Food: Food items (often represented as small objects or differently colored blocks)",
|
||||
"Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score.",
|
||||
"Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game.",
|
||||
],
|
||||
)
|
||||
|
||||
## Selected Elements
|
||||
|
||||
Game Grid: The game grid is a rectangular...
|
||||
|
||||
Snake: The player controls a snake that moves across the grid...
|
||||
|
||||
Food: Food items (often represented as small objects or differently colored blocks)
|
||||
|
||||
Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score.
|
||||
|
||||
Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game.
|
||||
|
||||
|
||||
## HTML Layout
|
||||
<!DOCTYPE html>
|
||||
HTML_LAYOUT = ActionNode(
|
||||
key="HTML Layout",
|
||||
expected_type=str,
|
||||
instruction="use standard HTML code",
|
||||
example="""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
|
|
@ -69,9 +66,14 @@ Game Over: The game ends when the snake collides with itself or an obstacle. At
|
|||
</div>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
)
|
||||
|
||||
## CSS Styles (styles.css)
|
||||
body {
|
||||
CSS_STYLES = ActionNode(
|
||||
key="CSS Styles",
|
||||
expected_type=str,
|
||||
instruction="use standard css code",
|
||||
example="""body {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
|
|
@ -119,19 +121,25 @@ body {
|
|||
color: #ff0000;
|
||||
display: none;
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
## Anything UNCLEAR
|
||||
There are no unclear points.
|
||||
ANYTHING_UNCLEAR = ActionNode(
|
||||
key="Anything UNCLEAR",
|
||||
expected_type=str,
|
||||
instruction="Mention any aspects of the project that are unclear and try to clarify them.",
|
||||
example="...",
|
||||
)
|
||||
|
||||
"""
|
||||
NODES = [
|
||||
UI_DESIGN_DESC,
|
||||
SELECTED_ELEMENTS,
|
||||
HTML_LAYOUT,
|
||||
CSS_STYLES,
|
||||
ANYTHING_UNCLEAR,
|
||||
]
|
||||
|
||||
OUTPUT_MAPPING = {
|
||||
"UI Design Description": (str, ...),
|
||||
"Selected Elements": (str, ...),
|
||||
"HTML Layout": (str, ...),
|
||||
"CSS Styles (styles.css)": (str, ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
UI_DESIGN_NODE = ActionNode.from_children("UI_DESIGN", NODES)
|
||||
|
||||
|
||||
def load_engine(func):
|
||||
|
|
@ -214,17 +222,15 @@ class UIDesign(Action):
|
|||
logger.info("Finish icon design using StableDiffusion API")
|
||||
|
||||
async def _save(self, css_content, html_content):
|
||||
save_dir = WORKSPACE_ROOT / "resources" / "codes"
|
||||
save_dir = CONFIG.workspace_path / "resources" / "codes"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# Save CSS and HTML content to files
|
||||
css_file_path = save_dir / "ui_design.css"
|
||||
html_file_path = save_dir / "ui_design.html"
|
||||
|
||||
with open(css_file_path, "w") as css_file:
|
||||
css_file.write(css_content)
|
||||
with open(html_file_path, "w") as html_file:
|
||||
html_file.write(html_content)
|
||||
css_file_path.write_text(css_content)
|
||||
html_file_path.write_text(html_content)
|
||||
|
||||
async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput:
|
||||
"""Run the UI Design action."""
|
||||
|
|
@ -232,9 +238,9 @@ class UIDesign(Action):
|
|||
context = requirements[-1].content
|
||||
ui_design_draft = self.parse_requirement(context=context)
|
||||
# todo: parse requirements str
|
||||
prompt = PROMPT_TEMPLATE.format(context=ui_design_draft, format_example=FORMAT_EXAMPLE)
|
||||
prompt = PROMPT_TEMPLATE.format(context=ui_design_draft)
|
||||
logger.info(prompt)
|
||||
ui_describe = await self._aask_v1(prompt, "ui_design", OUTPUT_MAPPING)
|
||||
ui_describe = await UI_DESIGN_NODE.fill(prompt)
|
||||
logger.info(ui_describe.content)
|
||||
logger.info(ui_describe.instruct_content)
|
||||
css = self.parse_css_code(context=ui_describe.content)
|
||||
|
|
|
|||
4
tests/metagpt/serialize_deserialize/__init__.py
Normal file
4
tests/metagpt/serialize_deserialize/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 11:48 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
27
tests/metagpt/serialize_deserialize/test_action.py
Normal file
27
tests/metagpt/serialize_deserialize/test_action.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 11:48 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_action_serialize():
|
||||
action = Action()
|
||||
ser_action_dict = action.dict()
|
||||
assert "name" in ser_action_dict
|
||||
# assert "llm" not in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = Action()
|
||||
serialized_data = action.dict()
|
||||
|
||||
new_action = Action(**serialized_data)
|
||||
|
||||
assert new_action.name == ""
|
||||
assert new_action.llm == LLM()
|
||||
assert len(await new_action._aask("who are you")) > 0
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/26/2023 2:04 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.roles.architect import Architect
|
||||
|
||||
|
||||
def test_architect_serialize():
|
||||
role = Architect()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "_states" in ser_role_dict
|
||||
assert "_actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect_deserialize():
|
||||
role = Architect()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
new_role = Architect(**ser_role_dict)
|
||||
# new_role = Architect.deserialize(ser_role_dict)
|
||||
assert new_role.name == "Bob"
|
||||
assert len(new_role._actions) == 1
|
||||
assert isinstance(new_role._actions[0], Action)
|
||||
await new_role._actions[0].run(with_messages="write a cli snake game")
|
||||
88
tests/metagpt/serialize_deserialize/test_environment.py
Normal file
88
tests/metagpt/serialize_deserialize/test_environment.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
RoleC,
|
||||
serdeser_path,
|
||||
)
|
||||
|
||||
|
||||
def test_env_serialize():
|
||||
env = Environment()
|
||||
ser_env_dict = env.dict()
|
||||
assert "roles" in ser_env_dict
|
||||
|
||||
|
||||
def test_env_deserialize():
|
||||
env = Environment()
|
||||
env.publish_message(message=Message(content="test env serialize"))
|
||||
ser_env_dict = env.dict()
|
||||
new_env = Environment(**ser_env_dict)
|
||||
assert len(new_env.roles) == 0
|
||||
assert len(new_env.history) == 25
|
||||
|
||||
|
||||
def test_environment_serdeser():
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
message = Message(
|
||||
content="prd", instruct_content=ic_obj(**out_data), role="product manager", cause_by=any_to_str(UserRequirement)
|
||||
)
|
||||
|
||||
environment = Environment()
|
||||
role_c = RoleC()
|
||||
environment.add_role(role_c)
|
||||
environment.publish_message(message)
|
||||
|
||||
ser_data = environment.dict()
|
||||
assert ser_data["roles"]["Role C"]["name"] == "RoleC"
|
||||
|
||||
new_env: Environment = Environment(**ser_data)
|
||||
assert len(new_env.roles) == 1
|
||||
|
||||
assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states
|
||||
assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions
|
||||
assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK)
|
||||
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
|
||||
|
||||
|
||||
def test_environment_serdeser_v2():
|
||||
environment = Environment()
|
||||
pm = ProjectManager()
|
||||
environment.add_role(pm)
|
||||
|
||||
ser_data = environment.dict()
|
||||
|
||||
new_env: Environment = Environment(**ser_data)
|
||||
role = new_env.get_role(pm.profile)
|
||||
assert isinstance(role, ProjectManager)
|
||||
assert isinstance(role._actions[0], WriteTasks)
|
||||
assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks)
|
||||
|
||||
|
||||
def test_environment_serdeser_save():
|
||||
environment = Environment()
|
||||
role_c = RoleC()
|
||||
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
environment.add_role(role_c)
|
||||
environment.serialize(stg_path)
|
||||
|
||||
new_env: Environment = Environment.deserialize(stg_path)
|
||||
assert len(new_env.roles) == 1
|
||||
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
|
||||
63
tests/metagpt/serialize_deserialize/test_memory.py
Normal file
63
tests/metagpt/serialize_deserialize/test_memory.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of memory
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path
|
||||
|
||||
|
||||
def test_memory_serdeser():
|
||||
msg1 = Message(role="Boss", content="write a snake game", cause_by=UserRequirement)
|
||||
|
||||
out_mapping = {"field2": (list[str], ...)}
|
||||
out_data = {"field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("system_design", out_mapping)
|
||||
msg2 = Message(
|
||||
role="Architect", instruct_content=ic_obj(**out_data), content="system design content", cause_by=WriteDesign
|
||||
)
|
||||
|
||||
memory = Memory()
|
||||
memory.add_batch([msg1, msg2])
|
||||
ser_data = memory.dict()
|
||||
|
||||
new_memory = Memory(**ser_data)
|
||||
assert new_memory.count() == 2
|
||||
new_msg2 = new_memory.get(2)[0]
|
||||
assert isinstance(new_msg2, BaseModel)
|
||||
assert isinstance(new_memory.storage[-1], BaseModel)
|
||||
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
|
||||
assert new_msg2.role == "Boss"
|
||||
|
||||
|
||||
def test_memory_serdeser_save():
|
||||
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)
|
||||
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("system_design", out_mapping)
|
||||
msg2 = Message(
|
||||
role="Architect", instruct_content=ic_obj(**out_data), content="system design content", cause_by=WriteDesign
|
||||
)
|
||||
|
||||
memory = Memory()
|
||||
memory.add_batch([msg1, msg2])
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
memory.serialize(stg_path)
|
||||
assert stg_path.joinpath("memory.json").exists()
|
||||
|
||||
new_memory = Memory.deserialize(stg_path)
|
||||
assert new_memory.count() == 2
|
||||
new_msg2 = new_memory.get(1)[0]
|
||||
assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"]
|
||||
assert new_msg2.cause_by == any_to_str(WriteDesign)
|
||||
assert len(new_memory.index) == 2
|
||||
|
||||
stg_path.joinpath("memory.json").unlink()
|
||||
21
tests/metagpt/serialize_deserialize/test_product_manager.py
Normal file
21
tests/metagpt/serialize_deserialize/test_product_manager.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/26/2023 2:07 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_product_manager_deserialize():
|
||||
role = ProductManager()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
new_role = ProductManager(**ser_role_dict)
|
||||
|
||||
assert new_role.name == "Alice"
|
||||
assert len(new_role._actions) == 2
|
||||
assert isinstance(new_role._actions[0], Action)
|
||||
await new_role._actions[0].run([Message(content="write a cli snake game")])
|
||||
30
tests/metagpt/serialize_deserialize/test_project_manager.py
Normal file
30
tests/metagpt/serialize_deserialize/test_project_manager.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/26/2023 2:06 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
|
||||
|
||||
def test_project_manager_serialize():
|
||||
role = ProjectManager()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "_states" in ser_role_dict
|
||||
assert "_actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_manager_deserialize():
|
||||
role = ProjectManager()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
|
||||
new_role = ProjectManager(**ser_role_dict)
|
||||
assert new_role.name == "Eve"
|
||||
assert len(new_role._actions) == 1
|
||||
assert isinstance(new_role._actions[0], Action)
|
||||
assert isinstance(new_role._actions[0], WriteTasks)
|
||||
# await new_role._actions[0].run(context="write a cli snake game")
|
||||
96
tests/metagpt/serialize_deserialize/test_role.py
Normal file
96
tests/metagpt/serialize_deserialize/test_role.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/23/2023 4:49 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import format_trackback_info
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
RoleA,
|
||||
RoleB,
|
||||
RoleC,
|
||||
serdeser_path,
|
||||
)
|
||||
|
||||
|
||||
def test_roles():
|
||||
role_a = RoleA()
|
||||
assert len(role_a._rc.watch) == 1
|
||||
role_b = RoleB()
|
||||
assert len(role_a._rc.watch) == 1
|
||||
assert len(role_b._rc.watch) == 1
|
||||
|
||||
|
||||
def test_role_serialize():
|
||||
role = Role()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "_states" in ser_role_dict
|
||||
assert "_actions" in ser_role_dict
|
||||
|
||||
|
||||
def test_engineer_serialize():
|
||||
role = Engineer()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "_states" in ser_role_dict
|
||||
assert "_actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer_deserialize():
|
||||
role = Engineer(use_code_review=True)
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
|
||||
new_role = Engineer(**ser_role_dict)
|
||||
assert new_role.name == "Alex"
|
||||
assert new_role.use_code_review is True
|
||||
assert len(new_role._actions) == 1
|
||||
assert isinstance(new_role._actions[0], WriteCode)
|
||||
# await new_role._actions[0].run(context="write a cli snake game", filename="test_code")
|
||||
|
||||
|
||||
def test_role_serdeser_save():
|
||||
stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles")
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
pm = ProductManager()
|
||||
role_tag = f"{pm.__class__.__name__}_{pm.name}"
|
||||
stg_path = stg_path_prefix.joinpath(role_tag)
|
||||
pm.serialize(stg_path)
|
||||
|
||||
new_pm = Role.deserialize(stg_path)
|
||||
assert new_pm.name == pm.name
|
||||
assert len(new_pm.get_memories(1)) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_serdeser_interrupt():
|
||||
role_c = RoleC()
|
||||
shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = SERDESER_PATH.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}")
|
||||
try:
|
||||
await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
except Exception:
|
||||
logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}")
|
||||
role_c.serialize(stg_path)
|
||||
|
||||
assert role_c._rc.memory.count() == 1
|
||||
|
||||
new_role_a: Role = Role.deserialize(stg_path)
|
||||
assert new_role_a._rc.state == 1
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
38
tests/metagpt/serialize_deserialize/test_schema.py
Normal file
38
tests/metagpt/serialize_deserialize/test_schema.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of schema ser&deser
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage
|
||||
|
||||
|
||||
def test_message_serdeser():
|
||||
out_mapping = {"field3": (str, ...), "field4": (list[str], ...)}
|
||||
out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
|
||||
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
|
||||
ser_data = message.dict()
|
||||
assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode"
|
||||
assert ser_data["instruct_content"]["class"] == "code"
|
||||
|
||||
new_message = Message(**ser_data)
|
||||
assert new_message.cause_by == any_to_str(WriteCode)
|
||||
assert new_message.cause_by in [any_to_str(WriteCode)]
|
||||
assert new_message.instruct_content == ic_obj(**out_data)
|
||||
|
||||
|
||||
def test_message_without_postprocess():
|
||||
"""to explain `instruct_content` should be postprocessed"""
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
message = MockMessage(content="code", instruct_content=ic_obj(**out_data))
|
||||
ser_data = message.dict()
|
||||
assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]}
|
||||
|
||||
new_message = MockMessage(**ser_data)
|
||||
assert new_message.instruct_content != ic_obj(**out_data)
|
||||
87
tests/metagpt/serialize_deserialize/test_serdeser_base.py
Normal file
87
tests/metagpt/serialize_deserialize/test_serdeser_base.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : base test actions / roles used in unittest
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.roles.role import Role, RoleReactMode
|
||||
|
||||
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
|
||||
|
||||
|
||||
class MockMessage(BaseModel):
|
||||
"""to test normal dict without postprocess"""
|
||||
|
||||
content: str = ""
|
||||
instruct_content: BaseModel = Field(default=None)
|
||||
|
||||
|
||||
class ActionPass(Action):
|
||||
name: str = Field(default="ActionPass")
|
||||
|
||||
async def run(self, messages: list["Message"]) -> ActionOutput:
|
||||
await asyncio.sleep(5) # sleep to make other roles can watch the executed Message
|
||||
output_mapping = {"result": (str, ...)}
|
||||
pass_class = ActionNode.create_model_class("pass", output_mapping)
|
||||
pass_output = ActionOutput("ActionPass run passed", pass_class(**{"result": "pass result"}))
|
||||
|
||||
return pass_output
|
||||
|
||||
|
||||
class ActionOK(Action):
|
||||
name: str = Field(default="ActionOK")
|
||||
|
||||
async def run(self, messages: list["Message"]) -> str:
|
||||
await asyncio.sleep(5)
|
||||
return "ok"
|
||||
|
||||
|
||||
class ActionRaise(Action):
|
||||
name: str = Field(default="ActionRaise")
|
||||
|
||||
async def run(self, messages: list["Message"]) -> str:
|
||||
raise RuntimeError("parse error in ActionRaise")
|
||||
|
||||
|
||||
class RoleA(Role):
|
||||
name: str = Field(default="RoleA")
|
||||
profile: str = Field(default="Role A")
|
||||
goal: str = "RoleA's goal"
|
||||
constraints: str = "RoleA's constraints"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(RoleA, self).__init__(**kwargs)
|
||||
self._init_actions([ActionPass])
|
||||
self._watch([UserRequirement])
|
||||
|
||||
|
||||
class RoleB(Role):
|
||||
name: str = Field(default="RoleB")
|
||||
profile: str = Field(default="Role B")
|
||||
goal: str = "RoleB's goal"
|
||||
constraints: str = "RoleB's constraints"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(RoleB, self).__init__(**kwargs)
|
||||
self._init_actions([ActionOK, ActionRaise])
|
||||
self._watch([ActionPass])
|
||||
self._rc.react_mode = RoleReactMode.BY_ORDER
|
||||
|
||||
|
||||
class RoleC(Role):
|
||||
name: str = Field(default="RoleC")
|
||||
profile: str = Field(default="Role C")
|
||||
goal: str = "RoleC's goal"
|
||||
constraints: str = "RoleC's constraints"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(RoleC, self).__init__(**kwargs)
|
||||
self._init_actions([ActionOK, ActionRaise])
|
||||
self._watch([UserRequirement])
|
||||
self._rc.react_mode = RoleReactMode.BY_ORDER
|
||||
135
tests/metagpt/serialize_deserialize/test_team.py
Normal file
135
tests/metagpt/serialize_deserialize/test_team.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/27/2023 10:07 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager
|
||||
from metagpt.team import Team
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
RoleA,
|
||||
RoleB,
|
||||
RoleC,
|
||||
serdeser_path,
|
||||
)
|
||||
|
||||
|
||||
def test_team_deserialize():
|
||||
company = Team()
|
||||
|
||||
pm = ProductManager()
|
||||
arch = Architect()
|
||||
company.hire(
|
||||
[
|
||||
pm,
|
||||
arch,
|
||||
ProjectManager(),
|
||||
]
|
||||
)
|
||||
assert len(company.env.get_roles()) == 3
|
||||
ser_company = company.dict()
|
||||
new_company = Team(**ser_company)
|
||||
|
||||
assert len(new_company.env.get_roles()) == 3
|
||||
assert new_company.env.get_role(pm.profile) is not None
|
||||
|
||||
new_pm = new_company.env.get_role(pm.profile)
|
||||
assert type(new_pm) == ProductManager
|
||||
assert new_company.env.get_role(pm.profile) is not None
|
||||
assert new_company.env.get_role(arch.profile) is not None
|
||||
|
||||
|
||||
def test_team_serdeser_save():
|
||||
company = Team()
|
||||
company.hire([RoleC()])
|
||||
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company.serialize(stg_path=stg_path)
|
||||
|
||||
new_company = Team.deserialize(stg_path)
|
||||
|
||||
assert len(new_company.env.roles) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover():
|
||||
idea = "write a snake game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company = Team()
|
||||
role_c = RoleC()
|
||||
company.hire([role_c])
|
||||
company.run_project(idea)
|
||||
await company.run(n_round=4)
|
||||
|
||||
ser_data = company.dict()
|
||||
new_company = Team(**ser_data)
|
||||
|
||||
new_role_c = new_company.env.get_role(role_c.profile)
|
||||
# assert new_role_c._rc.memory == role_c._rc.memory # TODO
|
||||
assert new_role_c._rc.env != role_c._rc.env # TODO
|
||||
assert type(list(new_company.env.roles.values())[0]._actions[0]) == ActionOK
|
||||
|
||||
new_company.run_project(idea)
|
||||
await new_company.run(n_round=4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover_save():
|
||||
idea = "write a 2048 web game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company = Team()
|
||||
role_c = RoleC()
|
||||
company.hire([role_c])
|
||||
company.run_project(idea)
|
||||
await company.run(n_round=4)
|
||||
|
||||
new_company = Team.deserialize(stg_path)
|
||||
new_role_c = new_company.env.get_role(role_c.profile)
|
||||
# assert new_role_c._rc.memory == role_c._rc.memory
|
||||
assert new_role_c._rc.env != role_c._rc.env
|
||||
assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=`
|
||||
assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo`
|
||||
assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news`
|
||||
|
||||
new_company.run_project(idea)
|
||||
await new_company.run(n_round=4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover_multi_roles_save():
|
||||
idea = "write a snake game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
role_a = RoleA()
|
||||
role_b = RoleB()
|
||||
|
||||
assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", "RoleA"}
|
||||
assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", "RoleB"}
|
||||
assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"}
|
||||
|
||||
company = Team()
|
||||
company.hire([role_a, role_b])
|
||||
company.run_project(idea)
|
||||
await company.run(n_round=4)
|
||||
|
||||
logger.info("Team recovered")
|
||||
|
||||
new_company = Team.deserialize(stg_path)
|
||||
new_company.run_project(idea)
|
||||
|
||||
assert new_company.env.get_role(role_b.profile)._rc.state == 1
|
||||
|
||||
await new_company.run(n_round=4)
|
||||
32
tests/metagpt/serialize_deserialize/test_write_code.py
Normal file
32
tests/metagpt/serialize_deserialize/test_write_code.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/23/2023 10:56 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
|
||||
|
||||
def test_write_design_serialize():
|
||||
action = WriteCode()
|
||||
ser_action_dict = action.dict()
|
||||
assert ser_action_dict["name"] == "WriteCode"
|
||||
# assert "llm" in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_deserialize():
|
||||
context = CodingContext(
|
||||
filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers")
|
||||
)
|
||||
doc = Document(content=context.json())
|
||||
action = WriteCode(context=doc)
|
||||
serialized_data = action.dict()
|
||||
new_action = WriteCode(**serialized_data)
|
||||
|
||||
assert new_action.name == "WriteCode"
|
||||
assert new_action.llm == LLM()
|
||||
await action.run()
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of WriteCodeReview SerDeser
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCodeReview
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_review_deserialize():
|
||||
code_content = """
|
||||
def div(a: int, b: int = 0):
|
||||
return a / b
|
||||
"""
|
||||
context = CodingContext(
|
||||
filename="test_op.py",
|
||||
design_doc=Document(content="divide two numbers"),
|
||||
code_doc=Document(content=code_content),
|
||||
)
|
||||
|
||||
action = WriteCodeReview(context=context)
|
||||
serialized_data = action.dict()
|
||||
assert serialized_data["name"] == "WriteCodeReview"
|
||||
|
||||
new_action = WriteCodeReview(**serialized_data)
|
||||
|
||||
assert new_action.name == "WriteCodeReview"
|
||||
assert new_action.llm == LLM()
|
||||
await new_action.run()
|
||||
42
tests/metagpt/serialize_deserialize/test_write_design.py
Normal file
42
tests/metagpt/serialize_deserialize/test_write_design.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 8:19 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteDesign, WriteTasks
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_write_design_serialize():
|
||||
action = WriteDesign()
|
||||
ser_action_dict = action.dict()
|
||||
assert "name" in ser_action_dict
|
||||
# assert "llm" in ser_action_dict # not export
|
||||
|
||||
|
||||
def test_write_task_serialize():
|
||||
action = WriteTasks()
|
||||
ser_action_dict = action.dict()
|
||||
assert "name" in ser_action_dict
|
||||
# assert "llm" in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_design_deserialize():
|
||||
action = WriteDesign()
|
||||
serialized_data = action.dict()
|
||||
new_action = WriteDesign(**serialized_data)
|
||||
assert new_action.name == ""
|
||||
assert new_action.llm == LLM()
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_task_deserialize():
|
||||
action = WriteTasks()
|
||||
serialized_data = action.dict()
|
||||
new_action = WriteTasks(**serialized_data)
|
||||
assert new_action.name == "CreateTasks"
|
||||
assert new_action.llm == LLM()
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
28
tests/metagpt/serialize_deserialize/test_write_prd.py
Normal file
28
tests/metagpt/serialize_deserialize/test_write_prd.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 1:47 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_action_serialize():
|
||||
action = WritePRD()
|
||||
ser_action_dict = action.dict()
|
||||
assert "name" in ser_action_dict
|
||||
# assert "llm" in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = WritePRD()
|
||||
serialized_data = action.dict()
|
||||
new_action = WritePRD(**serialized_data)
|
||||
assert new_action.name == ""
|
||||
assert new_action.llm == LLM()
|
||||
action_output = await new_action.run(with_messages=Message(content="write a cli snake game"))
|
||||
assert len(action_output.content) > 0
|
||||
|
|
@ -6,15 +6,19 @@
|
|||
@File : test_environment.py
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.logs import logger
|
||||
from metagpt.manager import Manager
|
||||
from metagpt.roles import Architect, ProductManager, Role
|
||||
from metagpt.schema import Message
|
||||
|
||||
serdeser_path = Path(__file__).absolute().parent.joinpath("../data/serdeser_storage")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env():
|
||||
|
|
@ -22,34 +26,33 @@ def env():
|
|||
|
||||
|
||||
def test_add_role(env: Environment):
|
||||
role = ProductManager("Alice", "product manager", "create a new product", "limited resources")
|
||||
role = ProductManager(
|
||||
name="Alice", profile="product manager", goal="create a new product", constraints="limited resources"
|
||||
)
|
||||
env.add_role(role)
|
||||
assert env.get_role(role.profile) == role
|
||||
|
||||
|
||||
def test_get_roles(env: Environment):
|
||||
role1 = Role("Alice", "product manager", "create a new product", "limited resources")
|
||||
role2 = Role("Bob", "engineer", "develop the new product", "short deadline")
|
||||
role1 = Role(name="Alice", profile="product manager", goal="create a new product", constraints="limited resources")
|
||||
role2 = Role(name="Bob", profile="engineer", goal="develop the new product", constraints="short deadline")
|
||||
env.add_role(role1)
|
||||
env.add_role(role2)
|
||||
roles = env.get_roles()
|
||||
assert roles == {role1.profile: role1, role2.profile: role2}
|
||||
|
||||
|
||||
def test_set_manager(env: Environment):
|
||||
manager = Manager()
|
||||
env.set_manager(manager)
|
||||
assert env.manager == manager
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_and_process_message(env: Environment):
|
||||
product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限")
|
||||
architect = Architect("Bob", "Architect", "设计一个可用、高效、较低成本的系统,包括数据结构与接口", "资源有限,需要节省成本")
|
||||
product_manager = ProductManager(name="Alice", profile="Product Manager", goal="做AI Native产品", constraints="资源有限")
|
||||
architect = Architect(
|
||||
name="Bob", profile="Architect", goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口", constraints="资源有限,需要节省成本"
|
||||
)
|
||||
|
||||
env.add_roles([product_manager, architect])
|
||||
|
||||
env.set_manager(Manager())
|
||||
env.publish_message(Message(role="BOSS", content="需要一个基于LLM做总结的搜索引擎", cause_by=BossRequirement))
|
||||
env.publish_message(Message(role="User", content="需要一个基于LLM做总结的搜索引擎", cause_by=UserRequirement))
|
||||
|
||||
await env.run(k=2)
|
||||
logger.info(f"{env.history=}")
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ from metagpt.logs import logger
|
|||
@pytest.mark.usefixtures("llm_api")
|
||||
class TestGPT:
|
||||
def test_llm_api_ask(self, llm_api):
|
||||
answer = llm_api.ask('hello chatgpt')
|
||||
answer = llm_api.ask("hello chatgpt")
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
# def test_gptapi_ask_batch(self, llm_api):
|
||||
|
|
@ -22,22 +23,29 @@ class TestGPT:
|
|||
# assert len(answer) > 0
|
||||
|
||||
def test_llm_api_ask_code(self, llm_api):
|
||||
answer = llm_api.ask_code(['请扮演一个Google Python专家工程师,如果理解,回复明白', '写一个hello world'])
|
||||
answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"])
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_aask(self, llm_api):
|
||||
answer = await llm_api.aask('hello chatgpt')
|
||||
answer = await llm_api.aask("hello chatgpt")
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_aask_code(self, llm_api):
|
||||
answer = await llm_api.aask_code(['请扮演一个Google Python专家工程师,如果理解,回复明白', '写一个hello world'])
|
||||
answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"])
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_costs(self, llm_api):
|
||||
await llm_api.aask('hello chatgpt')
|
||||
await llm_api.aask("hello chatgpt")
|
||||
costs = llm_api.get_costs()
|
||||
logger.info(costs)
|
||||
assert costs.total_cost > 0
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -18,17 +18,21 @@ def llm():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_aask(llm):
|
||||
assert len(await llm.aask('hello world')) > 0
|
||||
assert len(await llm.aask("hello world")) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_aask_batch(llm):
|
||||
assert len(await llm.aask_batch(['hi', 'write python hello world.'])) > 0
|
||||
assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_acompletion(llm):
|
||||
hello_msg = [{'role': 'user', 'content': 'hello'}]
|
||||
hello_msg = [{"role": "user", "content": "hello"}]
|
||||
assert len(await llm.acompletion(hello_msg)) > 0
|
||||
assert len(await llm.acompletion_batch([hello_msg])) > 0
|
||||
assert len(await llm.acompletion_batch_text([hello_msg])) > 0
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
@Time : 2023/5/16 10:57
|
||||
@Author : alexanderwu
|
||||
@File : test_message.py
|
||||
@Modified By: mashenquan, 2023-11-1. Modify coding style.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
|
|
@ -11,26 +12,30 @@ from metagpt.schema import AIMessage, Message, RawMessage, SystemMessage, UserMe
|
|||
|
||||
|
||||
def test_message():
|
||||
msg = Message(role='User', content='WTF')
|
||||
assert msg.to_dict()['role'] == 'User'
|
||||
assert 'User' in str(msg)
|
||||
msg = Message(role="User", content="WTF")
|
||||
assert msg.to_dict()["role"] == "User"
|
||||
assert "User" in str(msg)
|
||||
|
||||
|
||||
def test_all_messages():
|
||||
test_content = 'test_message'
|
||||
test_content = "test_message"
|
||||
msgs = [
|
||||
UserMessage(test_content),
|
||||
SystemMessage(test_content),
|
||||
AIMessage(test_content),
|
||||
Message(test_content, role='QA')
|
||||
Message(test_content, role="QA"),
|
||||
]
|
||||
for msg in msgs:
|
||||
assert msg.content == test_content
|
||||
|
||||
|
||||
def test_raw_message():
|
||||
msg = RawMessage(role='user', content='raw')
|
||||
assert msg['role'] == 'user'
|
||||
assert msg['content'] == 'raw'
|
||||
msg = RawMessage(role="user", content="raw")
|
||||
assert msg["role"] == "user"
|
||||
assert msg["content"] == "raw"
|
||||
with pytest.raises(KeyError):
|
||||
assert msg['1'] == 1, "KeyError: '1'"
|
||||
assert msg["1"] == 1, "KeyError: '1'"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
342
tests/metagpt/test_prompt.py
Normal file
342
tests/metagpt/test_prompt.py
Normal file
|
|
@ -0,0 +1,342 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/11 14:45
|
||||
@Author : alexanderwu
|
||||
@File : test_llm.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.llm import LLM
|
||||
|
||||
CODE_REVIEW_SMALLEST_CONTEXT = """
|
||||
## game.js
|
||||
```Code
|
||||
// game.js
|
||||
class Game {
|
||||
constructor() {
|
||||
this.board = this.createEmptyBoard();
|
||||
this.score = 0;
|
||||
this.bestScore = 0;
|
||||
}
|
||||
|
||||
createEmptyBoard() {
|
||||
const board = [];
|
||||
for (let i = 0; i < 4; i++) {
|
||||
board[i] = [0, 0, 0, 0];
|
||||
}
|
||||
return board;
|
||||
}
|
||||
|
||||
startGame() {
|
||||
this.board = this.createEmptyBoard();
|
||||
this.score = 0;
|
||||
this.addRandomTile();
|
||||
this.addRandomTile();
|
||||
}
|
||||
|
||||
addRandomTile() {
|
||||
let emptyCells = [];
|
||||
for (let r = 0; r < 4; r++) {
|
||||
for (let c = 0; c < 4; c++) {
|
||||
if (this.board[r][c] === 0) {
|
||||
emptyCells.push({ r, c });
|
||||
}
|
||||
}
|
||||
}
|
||||
if (emptyCells.length > 0) {
|
||||
let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)];
|
||||
this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 2 : 4;
|
||||
}
|
||||
}
|
||||
|
||||
move(direction) {
|
||||
// This function will handle the logic for moving tiles
|
||||
// in the specified direction and merging them
|
||||
// It will also update the score and add a new random tile if the move is successful
|
||||
// The actual implementation of this function is complex and would require
|
||||
// a significant amount of code to handle all the cases for moving and merging tiles
|
||||
// For the purposes of this example, we will not implement the full logic
|
||||
// Instead, we will just call addRandomTile to simulate a move
|
||||
this.addRandomTile();
|
||||
}
|
||||
|
||||
getBoard() {
|
||||
return this.board;
|
||||
}
|
||||
|
||||
getScore() {
|
||||
return this.score;
|
||||
}
|
||||
|
||||
getBestScore() {
|
||||
return this.bestScore;
|
||||
}
|
||||
|
||||
setBestScore(score) {
|
||||
this.bestScore = score;
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
"""
|
||||
|
||||
MOVE_DRAFT = """
|
||||
## move function draft
|
||||
|
||||
```javascript
|
||||
move(direction) {
|
||||
let moved = false;
|
||||
switch (direction) {
|
||||
case 'up':
|
||||
for (let c = 0; c < 4; c++) {
|
||||
for (let r = 1; r < 4; r++) {
|
||||
if (this.board[r][c] !== 0) {
|
||||
let row = r;
|
||||
while (row > 0 && this.board[row - 1][c] === 0) {
|
||||
this.board[row - 1][c] = this.board[row][c];
|
||||
this.board[row][c] = 0;
|
||||
row--;
|
||||
moved = true;
|
||||
}
|
||||
if (row > 0 && this.board[row - 1][c] === this.board[row][c]) {
|
||||
this.board[row - 1][c] *= 2;
|
||||
this.board[row][c] = 0;
|
||||
this.score += this.board[row - 1][c];
|
||||
moved = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 'down':
|
||||
// Implement logic for moving tiles down
|
||||
// Similar to the 'up' case but iterating in reverse order
|
||||
// and checking for merging in the opposite direction
|
||||
break;
|
||||
case 'left':
|
||||
// Implement logic for moving tiles left
|
||||
// Similar to the 'up' case but iterating over columns first
|
||||
// and checking for merging in the opposite direction
|
||||
break;
|
||||
case 'right':
|
||||
// Implement logic for moving tiles right
|
||||
// Similar to the 'up' case but iterating over columns in reverse order
|
||||
// and checking for merging in the opposite direction
|
||||
break;
|
||||
}
|
||||
|
||||
if (moved) {
|
||||
this.addRandomTile();
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
FUNCTION_TO_MERMAID_CLASS = """
|
||||
## context
|
||||
```
|
||||
class UIDesign(Action):
|
||||
#Class representing the UI Design action.
|
||||
def __init__(self, name, context=None, llm=None):
|
||||
super().__init__(name, context, llm) # 需要调用LLM进一步丰富UI设计的prompt
|
||||
@parse
|
||||
def parse_requirement(self, context: str):
|
||||
#Parse UI Design draft from the context using regex.
|
||||
pattern = r"## UI Design draft.*?\n(.*?)## Anything UNCLEAR"
|
||||
return context, pattern
|
||||
@parse
|
||||
def parse_ui_elements(self, context: str):
|
||||
#Parse Selected Elements from the context using regex.
|
||||
pattern = r"## Selected Elements.*?\n(.*?)## HTML Layout"
|
||||
return context, pattern
|
||||
@parse
|
||||
def parse_css_code(self, context: str):
|
||||
pattern = r"```css.*?\n(.*?)## Anything UNCLEAR"
|
||||
return context, pattern
|
||||
@parse
|
||||
def parse_html_code(self, context: str):
|
||||
pattern = r"```html.*?\n(.*?)```"
|
||||
return context, pattern
|
||||
async def draw_icons(self, context, *args, **kwargs):
|
||||
#Draw icons using SDEngine.
|
||||
engine = SDEngine()
|
||||
icon_prompts = self.parse_ui_elements(context)
|
||||
icons = icon_prompts.split("\n")
|
||||
icons = [s for s in icons if len(s.strip()) > 0]
|
||||
prompts_batch = []
|
||||
for icon_prompt in icons:
|
||||
# fixme: 添加icon lora
|
||||
prompt = engine.construct_payload(icon_prompt + ".<lora:WZ0710_AW81e-3_30e3b128d64T32_goon0.5>")
|
||||
prompts_batch.append(prompt)
|
||||
await engine.run_t2i(prompts_batch)
|
||||
logger.info("Finish icon design using StableDiffusion API")
|
||||
async def _save(self, css_content, html_content):
|
||||
save_dir = CONFIG.workspace_path / "resources" / "codes"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# Save CSS and HTML content to files
|
||||
css_file_path = save_dir / "ui_design.css"
|
||||
html_file_path = save_dir / "ui_design.html"
|
||||
with open(css_file_path, "w") as css_file:
|
||||
css_file.write(css_content)
|
||||
with open(html_file_path, "w") as html_file:
|
||||
html_file.write(html_content)
|
||||
async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput:
|
||||
#Run the UI Design action.
|
||||
# fixme: update prompt (根据需求细化prompt)
|
||||
context = requirements[-1].content
|
||||
ui_design_draft = self.parse_requirement(context=context)
|
||||
# todo: parse requirements str
|
||||
prompt = PROMPT_TEMPLATE.format(context=ui_design_draft, format_example=FORMAT_EXAMPLE)
|
||||
logger.info(prompt)
|
||||
ui_describe = await self._aask_v1(prompt, "ui_design", OUTPUT_MAPPING)
|
||||
logger.info(ui_describe.content)
|
||||
logger.info(ui_describe.instruct_content)
|
||||
css = self.parse_css_code(context=ui_describe.content)
|
||||
html = self.parse_html_code(context=ui_describe.content)
|
||||
await self._save(css_content=css, html_content=html)
|
||||
await self.draw_icons(ui_describe.content)
|
||||
return ui_describe
|
||||
```
|
||||
-----
|
||||
## format example
|
||||
[CONTENT]
|
||||
{
|
||||
"ClassView": "classDiagram\n class A {\n -int x\n +int y\n -int speed\n -int direction\n +__init__(x: int, y: int, speed: int, direction: int)\n +change_direction(new_direction: int) None\n +move() None\n }\n "
|
||||
}
|
||||
[/CONTENT]
|
||||
## nodes: "<node>: <type> # <comment>"
|
||||
- ClassView: <class 'str'> # Generate the mermaid class diagram corresponding to source code in "context."
|
||||
## constraint
|
||||
- Language: Please use the same language as the user input.
|
||||
- Format: output wrapped inside [CONTENT][/CONTENT] as format example, nothing else.
|
||||
## action
|
||||
Fill in the above nodes(ClassView) based on the format example.
|
||||
"""
|
||||
|
||||
MOVE_FUNCTION = """
|
||||
## move function implementation
|
||||
|
||||
```javascript
|
||||
move(direction) {
|
||||
let moved = false;
|
||||
switch (direction) {
|
||||
case 'up':
|
||||
for (let c = 0; c < 4; c++) {
|
||||
for (let r = 1; r < 4; r++) {
|
||||
if (this.board[r][c] !== 0) {
|
||||
let row = r;
|
||||
while (row > 0 && this.board[row - 1][c] === 0) {
|
||||
this.board[row - 1][c] = this.board[row][c];
|
||||
this.board[row][c] = 0;
|
||||
row--;
|
||||
moved = true;
|
||||
}
|
||||
if (row > 0 && this.board[row - 1][c] === this.board[row][c]) {
|
||||
this.board[row - 1][c] *= 2;
|
||||
this.board[row][c] = 0;
|
||||
this.score += this.board[row - 1][c];
|
||||
moved = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 'down':
|
||||
for (let c = 0; c < 4; c++) {
|
||||
for (let r = 2; r >= 0; r--) {
|
||||
if (this.board[r][c] !== 0) {
|
||||
let row = r;
|
||||
while (row < 3 && this.board[row + 1][c] === 0) {
|
||||
this.board[row + 1][c] = this.board[row][c];
|
||||
this.board[row][c] = 0;
|
||||
row++;
|
||||
moved = true;
|
||||
}
|
||||
if (row < 3 && this.board[row + 1][c] === this.board[row][c]) {
|
||||
this.board[row + 1][c] *= 2;
|
||||
this.board[row][c] = 0;
|
||||
this.score += this.board[row + 1][c];
|
||||
moved = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 'left':
|
||||
for (let r = 0; r < 4; r++) {
|
||||
for (let c = 1; c < 4; c++) {
|
||||
if (this.board[r][c] !== 0) {
|
||||
let col = c;
|
||||
while (col > 0 && this.board[r][col - 1] === 0) {
|
||||
this.board[r][col - 1] = this.board[r][col];
|
||||
this.board[r][col] = 0;
|
||||
col--;
|
||||
moved = true;
|
||||
}
|
||||
if (col > 0 && this.board[r][col - 1] === this.board[r][col]) {
|
||||
this.board[r][col - 1] *= 2;
|
||||
this.board[r][col] = 0;
|
||||
this.score += this.board[r][col - 1];
|
||||
moved = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 'right':
|
||||
for (let r = 0; r < 4; r++) {
|
||||
for (let c = 2; c >= 0; c--) {
|
||||
if (this.board[r][c] !== 0) {
|
||||
let col = c;
|
||||
while (col < 3 && this.board[r][col + 1] === 0) {
|
||||
this.board[r][col + 1] = this.board[r][col];
|
||||
this.board[r][col] = 0;
|
||||
col++;
|
||||
moved = true;
|
||||
}
|
||||
if (col < 3 && this.board[r][col + 1] === this.board[r][col]) {
|
||||
this.board[r][col + 1] *= 2;
|
||||
this.board[r][col] = 0;
|
||||
this.score += this.board[r][col + 1];
|
||||
moved = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (moved) {
|
||||
this.addRandomTile();
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm():
|
||||
return LLM()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_code_review(llm):
|
||||
choices = [
|
||||
"Please review the move function code above. Should it be refactor?",
|
||||
"Please implement the move function",
|
||||
"Please write a draft for the move function in order to implement it",
|
||||
]
|
||||
# prompt = CODE_REVIEW_SMALLEST_CONTEXT+ "\n\n" + MOVE_DRAFT + "\n\n" + choices[1]
|
||||
# rsp = await llm.aask(prompt)
|
||||
|
||||
prompt = CODE_REVIEW_SMALLEST_CONTEXT + "\n\n" + MOVE_FUNCTION + "\n\n" + choices[0]
|
||||
prompt = FUNCTION_TO_MERMAID_CLASS
|
||||
|
||||
_ = await llm.aask(prompt)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# pytest.main([__file__, "-s"])
|
||||
|
|
@ -4,11 +4,98 @@
|
|||
@Time : 2023/5/11 14:44
|
||||
@Author : alexanderwu
|
||||
@File : test_role.py
|
||||
@Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for
|
||||
the utilization of the new message distribution feature in message handling.
|
||||
@Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing
|
||||
functionality is to be consolidated into the `Environment` class.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import Action, ActionOutput, UserRequirement
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
def test_role_desc():
|
||||
i = Role(profile='Sales', desc='Best Seller')
|
||||
assert i.profile == 'Sales'
|
||||
assert i._setting.desc == 'Best Seller'
|
||||
class MockAction(Action):
|
||||
async def run(self, messages, *args, **kwargs):
|
||||
assert messages
|
||||
return ActionOutput(content=messages[-1].content, instruct_content=messages[-1])
|
||||
|
||||
|
||||
class MockRole(Role):
|
||||
def __init__(self, name="", profile="", goal="", constraints="", desc=""):
|
||||
super().__init__(name=name, profile=profile, goal=goal, constraints=constraints, desc=desc)
|
||||
self._init_actions([MockAction()])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react():
|
||||
class Input(BaseModel):
|
||||
name: str
|
||||
profile: str
|
||||
goal: str
|
||||
constraints: str
|
||||
desc: str
|
||||
subscription: str
|
||||
|
||||
inputs = [
|
||||
{
|
||||
"name": "A",
|
||||
"profile": "Tester",
|
||||
"goal": "Test",
|
||||
"constraints": "constraints",
|
||||
"desc": "desc",
|
||||
"subscription": "start",
|
||||
}
|
||||
]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
role = MockRole(
|
||||
name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc
|
||||
)
|
||||
role.subscribe({seed.subscription})
|
||||
assert role._rc.watch == {any_to_str(UserRequirement)}
|
||||
assert role.name == seed.name
|
||||
assert role.profile == seed.profile
|
||||
assert role._setting.goal == seed.goal
|
||||
assert role._setting.constraints == seed.constraints
|
||||
assert role._setting.desc == seed.desc
|
||||
assert role.is_idle
|
||||
env = Environment()
|
||||
env.add_role(role)
|
||||
assert env.get_subscription(role) == {seed.subscription}
|
||||
env.publish_message(Message(content="test", msg_to=seed.subscription))
|
||||
assert not role.is_idle
|
||||
while not env.is_idle:
|
||||
await env.run()
|
||||
assert role.is_idle
|
||||
env.publish_message(Message(content="test", cause_by=seed.subscription))
|
||||
assert not role.is_idle
|
||||
while not env.is_idle:
|
||||
await env.run()
|
||||
assert role.is_idle
|
||||
tag = uuid.uuid4().hex
|
||||
role.subscribe({tag})
|
||||
assert env.get_subscription(role) == {tag}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_msg_to():
|
||||
m = Message(content="a", send_to=["a", MockRole, Message])
|
||||
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
|
||||
|
||||
m = Message(content="a", cause_by=MockAction, send_to={"a", MockRole, Message})
|
||||
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
|
||||
|
||||
m = Message(content="a", send_to=("a", MockRole, Message))
|
||||
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -4,18 +4,97 @@
|
|||
@Time : 2023/5/20 10:40
|
||||
@Author : alexanderwu
|
||||
@File : test_schema.py
|
||||
@Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for
|
||||
the utilization of the new feature of `Message` class.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_messages():
|
||||
test_content = 'test_message'
|
||||
test_content = "test_message"
|
||||
msgs = [
|
||||
UserMessage(test_content),
|
||||
SystemMessage(test_content),
|
||||
AIMessage(test_content),
|
||||
Message(test_content, role='QA')
|
||||
UserMessage(content=test_content),
|
||||
SystemMessage(content=test_content),
|
||||
AIMessage(content=test_content),
|
||||
Message(content=test_content, role="QA"),
|
||||
]
|
||||
text = str(msgs)
|
||||
roles = ['user', 'system', 'assistant', 'QA']
|
||||
roles = ["user", "system", "assistant", "QA"]
|
||||
assert all([i in text for i in roles])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_message():
|
||||
m = Message(content="a", role="v1")
|
||||
v = m.dump()
|
||||
d = json.loads(v)
|
||||
assert d
|
||||
assert d.get("content") == "a"
|
||||
assert d.get("role") == "v1"
|
||||
m.role = "v2"
|
||||
v = m.dump()
|
||||
assert v
|
||||
m = Message.load(v)
|
||||
assert m.content == "a"
|
||||
assert m.role == "v2"
|
||||
|
||||
m = Message(content="a", role="b", cause_by="c", x="d", send_to="c")
|
||||
assert m.content == "a"
|
||||
assert m.role == "b"
|
||||
assert m.send_to == {"c"}
|
||||
assert m.cause_by == "c"
|
||||
|
||||
m.cause_by = "Message"
|
||||
assert m.cause_by == "Message"
|
||||
m.cause_by = Action
|
||||
assert m.cause_by == any_to_str(Action)
|
||||
m.cause_by = Action()
|
||||
assert m.cause_by == any_to_str(Action)
|
||||
m.content = "b"
|
||||
assert m.content == "b"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_routes():
|
||||
m = Message(content="a", role="b", cause_by="c", x="d", send_to="c")
|
||||
m.send_to = "b"
|
||||
assert m.send_to == {"b"}
|
||||
m.send_to = {"e", Action}
|
||||
assert m.send_to == {"e", any_to_str(Action)}
|
||||
|
||||
|
||||
def test_message_serdeser():
|
||||
out_mapping = {"field3": (str, ...), "field4": (list[str], ...)}
|
||||
out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
|
||||
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
|
||||
message_dict = message.dict()
|
||||
assert message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode"
|
||||
assert message_dict["instruct_content"] == {
|
||||
"class": "code",
|
||||
"mapping": {"field3": "(<class 'str'>, Ellipsis)", "field4": "(list[str], Ellipsis)"},
|
||||
"value": {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]},
|
||||
}
|
||||
|
||||
new_message = Message(**message_dict)
|
||||
assert new_message.content == message.content
|
||||
assert new_message.instruct_content == message.instruct_content
|
||||
assert new_message.cause_by == message.cause_by
|
||||
assert new_message.instruct_content.field3 == out_data["field3"]
|
||||
|
||||
message = Message(content="code")
|
||||
message_dict = message.dict()
|
||||
new_message = Message(**message_dict)
|
||||
assert new_message.instruct_content is None
|
||||
assert new_message.cause_by == "metagpt.actions.add_requirement.UserRequirement"
|
||||
|
|
|
|||
|
|
@ -1,19 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/15 11:40
|
||||
@Author : alexanderwu
|
||||
@File : test_software_company.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.team import Team
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team():
|
||||
company = Team()
|
||||
company.start_project("做一个基础搜索引擎,可以支持知识库")
|
||||
history = await company.run(n_round=5)
|
||||
logger.info(history)
|
||||
28
tests/metagpt/test_startup.py
Normal file
28
tests/metagpt/test_startup.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/15 11:40
|
||||
@Author : alexanderwu
|
||||
@File : test_startup.py
|
||||
"""
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.team import Team
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team():
|
||||
# FIXME: we're now using "metagpt" cli, so the entrance should be replaced instead.
|
||||
company = Team()
|
||||
company.run_project("做一个基础搜索引擎,可以支持知识库")
|
||||
history = await company.run(n_round=5)
|
||||
logger.info(history)
|
||||
|
||||
|
||||
# def test_startup():
|
||||
# args = ["Make a 2048 game"]
|
||||
# result = runner.invoke(app, args)
|
||||
102
tests/metagpt/test_subscription.py
Normal file
102
tests/metagpt/test_subscription.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.subscription import SubscriptionRunner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscription_run():
|
||||
callback_done = 0
|
||||
|
||||
async def trigger():
|
||||
while True:
|
||||
yield Message("the latest news about OpenAI")
|
||||
await asyncio.sleep(3600 * 24)
|
||||
|
||||
class MockRole(Role):
|
||||
async def run(self, message=None):
|
||||
return Message("")
|
||||
|
||||
async def callback(message):
|
||||
nonlocal callback_done
|
||||
callback_done += 1
|
||||
|
||||
runner = SubscriptionRunner()
|
||||
|
||||
roles = []
|
||||
for _ in range(2):
|
||||
role = MockRole()
|
||||
roles.append(role)
|
||||
await runner.subscribe(role, trigger(), callback)
|
||||
|
||||
task = asyncio.get_running_loop().create_task(runner.run())
|
||||
|
||||
for _ in range(10):
|
||||
if callback_done == 2:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
raise TimeoutError("callback not call")
|
||||
|
||||
role = roles[0]
|
||||
assert role in runner.tasks
|
||||
await runner.unsubscribe(roles[0])
|
||||
|
||||
for _ in range(10):
|
||||
if role not in runner.tasks:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
raise TimeoutError("callback not call")
|
||||
|
||||
task.cancel()
|
||||
for i in runner.tasks.values():
|
||||
i.cancel()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscription_run_error(loguru_caplog):
|
||||
async def trigger1():
|
||||
while True:
|
||||
yield Message("the latest news about OpenAI")
|
||||
await asyncio.sleep(3600 * 24)
|
||||
|
||||
async def trigger2():
|
||||
yield Message("the latest news about OpenAI")
|
||||
|
||||
class MockRole1(Role):
|
||||
async def run(self, message=None):
|
||||
raise RuntimeError
|
||||
|
||||
class MockRole2(Role):
|
||||
async def run(self, message=None):
|
||||
return Message("")
|
||||
|
||||
async def callback(msg: Message):
|
||||
print(msg)
|
||||
|
||||
runner = SubscriptionRunner()
|
||||
await runner.subscribe(MockRole1(), trigger1(), callback)
|
||||
with pytest.raises(RuntimeError):
|
||||
await runner.run()
|
||||
|
||||
await runner.subscribe(MockRole2(), trigger2(), callback)
|
||||
task = asyncio.get_running_loop().create_task(runner.run(False))
|
||||
|
||||
for _ in range(10):
|
||||
if not runner.tasks:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
raise TimeoutError("wait runner tasks empty timeout")
|
||||
|
||||
task.cancel()
|
||||
for i in runner.tasks.values():
|
||||
i.cancel()
|
||||
assert len(loguru_caplog.records) >= 2
|
||||
logs = "".join(loguru_caplog.messages)
|
||||
assert "run error" in logs
|
||||
assert "has completed" in logs
|
||||
13
tests/metagpt/test_team.py
Normal file
13
tests/metagpt/test_team.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of team
|
||||
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
from metagpt.team import Team
|
||||
|
||||
|
||||
def test_team():
|
||||
company = Team()
|
||||
company.hire([ProjectManager()])
|
||||
|
||||
assert len(company.environment.roles) == 1
|
||||
|
|
@ -1,23 +1,22 @@
|
|||
import pytest
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
from tests.data import sales_desc, store_desc
|
||||
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
|
||||
|
||||
|
||||
logger.add('./tests/data/test_ci.log')
|
||||
logger.add("./tests/data/test_ci.log")
|
||||
stock = "./tests/data/baba_stock.csv"
|
||||
|
||||
|
||||
# TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。
|
||||
class CreateStockIndicators(Action):
|
||||
@OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
|
||||
async def run(self, stock_path: str, indicators=['Simple Moving Average', 'BollingerBands']) -> pd.DataFrame:
|
||||
async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame:
|
||||
"""对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
|
||||
指标生成对应的三列: SMA, BB_upper, BB_lower
|
||||
指标生成对应的三列: SMA, BB_upper, BB_lower
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
@ -25,18 +24,20 @@ class CreateStockIndicators(Action):
|
|||
@pytest.mark.asyncio
|
||||
async def test_actions():
|
||||
# 计算指标
|
||||
indicators = ['Simple Moving Average', 'BollingerBands']
|
||||
indicators = ["Simple Moving Average", "BollingerBands"]
|
||||
stocker = CreateStockIndicators()
|
||||
df, msg = await stocker.run(stock, indicators=indicators)
|
||||
assert isinstance(df, pd.DataFrame)
|
||||
assert 'Close' in df.columns
|
||||
assert 'Date' in df.columns
|
||||
assert "Close" in df.columns
|
||||
assert "Date" in df.columns
|
||||
# 将df保存为文件,将文件路径传入到下一个action
|
||||
df_path = './tests/data/stock_indicators.csv'
|
||||
df_path = "./tests/data/stock_indicators.csv"
|
||||
df.to_csv(df_path)
|
||||
assert Path(df_path).is_file()
|
||||
# 可视化指标结果
|
||||
figure_path = './tests/data/figure_ci.png'
|
||||
figure_path = "./tests/data/figure_ci.png"
|
||||
ci_ploter = OpenCodeInterpreter()
|
||||
ci_ploter.chat(f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。")
|
||||
ci_ploter.chat(
|
||||
f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。"
|
||||
)
|
||||
assert Path(figure_path).is_file()
|
||||
|
|
|
|||
|
|
@ -20,8 +20,9 @@ from metagpt.tools.prompt_writer import (
|
|||
@pytest.mark.usefixtures("llm_api")
|
||||
def test_gpt_prompt_generator(llm_api):
|
||||
generator = GPTPromptGenerator()
|
||||
example = "商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " \
|
||||
"品牌:WonderLab 保质期:1年 产地:中国 净含量:450g"
|
||||
example = (
|
||||
"商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " "品牌:WonderLab 保质期:1年 产地:中国 净含量:450g"
|
||||
)
|
||||
|
||||
results = llm_api.ask_batch(generator.gen(example))
|
||||
logger.info(results)
|
||||
|
|
@ -46,7 +47,7 @@ def test_enron_template(llm_api):
|
|||
|
||||
results = template.gen(subj)
|
||||
assert len(results) > 0
|
||||
assert any("Write an email with the subject \"Meeting Agenda\"." in r for r in results)
|
||||
assert any('Write an email with the subject "Meeting Agenda".' in r for r in results)
|
||||
|
||||
|
||||
def test_beagec_template():
|
||||
|
|
@ -54,5 +55,6 @@ def test_beagec_template():
|
|||
|
||||
results = template.gen()
|
||||
assert len(results) > 0
|
||||
assert any("Edit and revise this document to improve its grammar, vocabulary, spelling, and style."
|
||||
in r for r in results)
|
||||
assert any(
|
||||
"Edit and revise this document to improve its grammar, vocabulary, spelling, and style." in r for r in results
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/22 02:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
#
|
||||
import os
|
||||
|
||||
from metagpt.tools.sd_engine import SDEngine, WORKSPACE_ROOT
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.sd_engine import SDEngine
|
||||
|
||||
|
||||
def test_sd_engine_init():
|
||||
|
|
@ -21,5 +22,5 @@ def test_sd_engine_generate_prompt():
|
|||
async def test_sd_engine_run_t2i():
|
||||
sd_engine = SDEngine()
|
||||
await sd_engine.run_t2i(prompts=["test"])
|
||||
img_path = WORKSPACE_ROOT / "resources" / "SD_Output" / "output_0.png"
|
||||
assert os.path.exists(img_path) == True
|
||||
img_path = CONFIG.workspace_path / "resources" / "SD_Output" / "output_0.png"
|
||||
assert os.path.exists(img_path)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ from metagpt.tools.search_engine import SearchEngine
|
|||
|
||||
class MockSearchEnine:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
|
||||
rets = [
|
||||
{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)
|
||||
]
|
||||
return "\n".join(rets) if as_string else rets
|
||||
|
||||
|
||||
|
|
@ -34,10 +36,14 @@ class MockSearchEnine:
|
|||
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
|
||||
],
|
||||
)
|
||||
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ):
|
||||
async def test_search_engine(
|
||||
search_engine_typpe,
|
||||
run_func,
|
||||
max_results,
|
||||
as_string,
|
||||
):
|
||||
search_engine = SearchEngine(search_engine_typpe, run_func)
|
||||
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import pytest
|
|||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine_meilisearch import DataSource, MeilisearchEngine
|
||||
|
||||
MASTER_KEY = '116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk'
|
||||
MASTER_KEY = "116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
@ -29,7 +29,7 @@ def test_meilisearch(search_engine_server):
|
|||
search_engine = MeilisearchEngine(url="http://localhost:7700", token=MASTER_KEY)
|
||||
|
||||
# 假设有一个名为"books"的数据源,包含要添加的文档库
|
||||
books_data_source = DataSource(name='books', url='https://example.com/books')
|
||||
books_data_source = DataSource(name="books", url="https://example.com/books")
|
||||
|
||||
# 假设有一个名为"documents"的文档库,包含要添加的文档
|
||||
documents = [
|
||||
|
|
@ -43,4 +43,4 @@ def test_meilisearch(search_engine_server):
|
|||
|
||||
# 添加文档库到搜索引擎
|
||||
search_engine.add_documents(books_data_source, documents)
|
||||
logger.info(search_engine.search('Book 1'))
|
||||
logger.info(search_engine.search("Book 1"))
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -16,7 +16,7 @@ from metagpt.tools.translator import Translator
|
|||
def test_translate(llm_api):
|
||||
poetries = [
|
||||
("Let life be beautiful like summer flowers", "花"),
|
||||
("The ancient Chinese poetries are all songs.", "中国")
|
||||
("The ancient Chinese poetries are all songs.", "中国"),
|
||||
]
|
||||
for i, j in poetries:
|
||||
prompt = Translator.translate_prompt(i)
|
||||
|
|
|
|||
|
|
@ -16,8 +16,12 @@ class TestUTWriter:
|
|||
tags = ["测试"] # "智能合同导入", "律师审查", "ai合同审查", "草拟合同&律师在线审查", "合同审批", "履约管理", "签约公司"]
|
||||
# 这里在文件中手动加入了两个测试标签的API
|
||||
|
||||
utg = UTGenerator(swagger_file=swagger_file, ut_py_path=UT_PY_PATH, questions_path=API_QUESTIONS_PATH,
|
||||
template_prefix=YFT_PROMPT_PREFIX)
|
||||
utg = UTGenerator(
|
||||
swagger_file=swagger_file,
|
||||
ut_py_path=UT_PY_PATH,
|
||||
questions_path=API_QUESTIONS_PATH,
|
||||
template_prefix=YFT_PROMPT_PREFIX,
|
||||
)
|
||||
ret = utg.generate_ut(include_tags=tags)
|
||||
# 后续加入对文件生成内容与数量的检验
|
||||
assert ret
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from metagpt.tools import WebBrowserEngineType, web_browser_engine
|
|||
@pytest.mark.parametrize(
|
||||
"browser_type, url, urls",
|
||||
[
|
||||
(WebBrowserEngineType.PLAYWRIGHT, "https://fuzhi.ai", ("https://fuzhi.ai",)),
|
||||
(WebBrowserEngineType.SELENIUM, "https://fuzhi.ai", ("https://fuzhi.ai",)),
|
||||
(WebBrowserEngineType.PLAYWRIGHT, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
(WebBrowserEngineType.SELENIUM, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
],
|
||||
ids=["playwright", "selenium"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ from metagpt.tools import web_browser_engine_playwright
|
|||
@pytest.mark.parametrize(
|
||||
"browser_type, use_proxy, kwagrs, url, urls",
|
||||
[
|
||||
("chromium", {"proxy": True}, {}, "https://fuzhi.ai", ("https://fuzhi.ai",)),
|
||||
("firefox", {}, {"ignore_https_errors": True}, "https://fuzhi.ai", ("https://fuzhi.ai",)),
|
||||
("webkit", {}, {"ignore_https_errors": True}, "https://fuzhi.ai", ("https://fuzhi.ai",)),
|
||||
("chromium", {"proxy": True}, {}, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
("firefox", {}, {"ignore_https_errors": True}, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
("webkit", {}, {"ignore_https_errors": True}, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
],
|
||||
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ from metagpt.tools import web_browser_engine_selenium
|
|||
@pytest.mark.parametrize(
|
||||
"browser_type, use_proxy, url, urls",
|
||||
[
|
||||
("chrome", True, "https://fuzhi.ai", ("https://fuzhi.ai",)),
|
||||
("firefox", False, "https://fuzhi.ai", ("https://fuzhi.ai",)),
|
||||
("edge", False, "https://fuzhi.ai", ("https://fuzhi.ai",)),
|
||||
("chrome", True, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
("firefox", False, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
("edge", False, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
],
|
||||
ids=["chrome-normal", "firefox-normal", "edge-normal"],
|
||||
)
|
||||
|
|
|
|||
29
tests/metagpt/utils/test_ahttp_client.py
Normal file
29
tests/metagpt/utils/test_ahttp_client.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of ahttp_client
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.ahttp_client import apost, apost_stream
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apost():
|
||||
result = await apost(url="https://www.baidu.com/")
|
||||
assert "百度一下" in result
|
||||
|
||||
result = await apost(
|
||||
url="http://aider.meizu.com/app/weather/listWeather", data={"cityIds": "101240101"}, as_json=True
|
||||
)
|
||||
assert result["code"] == "200"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apost_stream():
|
||||
result = apost_stream(url="https://www.baidu.com/")
|
||||
async for line in result:
|
||||
assert len(line) >= 0
|
||||
|
||||
result = apost_stream(url="http://aider.meizu.com/app/weather/listWeather", data={"cityIds": "101240101"})
|
||||
async for line in result:
|
||||
assert len(line) >= 0
|
||||
|
|
@ -131,10 +131,10 @@ class TestCodeParser:
|
|||
def test_parse_file_list(self, parser, text):
|
||||
result = parser.parse_file_list("Task list", text)
|
||||
print(result)
|
||||
assert result == ['task1', 'task2']
|
||||
assert result == ["task1", "task2"]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
t = TestCodeParser()
|
||||
t.test_parse_file_list(CodeParser(), t_text)
|
||||
# TestCodeParser.test_parse_file_list()
|
||||
|
|
|
|||
|
|
@ -4,27 +4,79 @@
|
|||
@Time : 2023/4/29 16:19
|
||||
@Author : alexanderwu
|
||||
@File : test_common.py
|
||||
@Modified by: mashenquan, 2023/11/21. Add unit tests.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Set
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import get_project_root
|
||||
from metagpt.actions import RunCode
|
||||
from metagpt.const import get_metagpt_root
|
||||
from metagpt.roles.tutorial_assistant import TutorialAssistant
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
|
||||
|
||||
class TestGetProjectRoot:
|
||||
def change_etc_dir(self):
|
||||
# current_directory = Path.cwd()
|
||||
abs_root = '/etc'
|
||||
abs_root = "/etc"
|
||||
os.chdir(abs_root)
|
||||
|
||||
def test_get_project_root(self):
|
||||
project_root = get_project_root()
|
||||
assert project_root.name == 'metagpt'
|
||||
project_root = get_metagpt_root()
|
||||
assert project_root.name == "metagpt"
|
||||
|
||||
def test_get_root_exception(self):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
self.change_etc_dir()
|
||||
get_project_root()
|
||||
get_metagpt_root()
|
||||
assert str(exc_info.value) == "Project root not found."
|
||||
|
||||
def test_any_to_str(self):
|
||||
class Input(BaseModel):
|
||||
x: Any
|
||||
want: str
|
||||
|
||||
inputs = [
|
||||
Input(x=TutorialAssistant, want="metagpt.roles.tutorial_assistant.TutorialAssistant"),
|
||||
Input(x=TutorialAssistant(), want="metagpt.roles.tutorial_assistant.TutorialAssistant"),
|
||||
Input(x=RunCode, want="metagpt.actions.run_code.RunCode"),
|
||||
Input(x=RunCode(), want="metagpt.actions.run_code.RunCode"),
|
||||
Input(x=Message, want="metagpt.schema.Message"),
|
||||
Input(x=Message(""), want="metagpt.schema.Message"),
|
||||
Input(x="A", want="A"),
|
||||
]
|
||||
for i in inputs:
|
||||
v = any_to_str(i.x)
|
||||
assert v == i.want
|
||||
|
||||
def test_any_to_str_set(self):
|
||||
class Input(BaseModel):
|
||||
x: Any
|
||||
want: Set
|
||||
|
||||
inputs = [
|
||||
Input(
|
||||
x=[TutorialAssistant, RunCode(), "a"],
|
||||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"},
|
||||
),
|
||||
Input(
|
||||
x={TutorialAssistant, RunCode(), "a"},
|
||||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"},
|
||||
),
|
||||
Input(
|
||||
x=(TutorialAssistant, RunCode(), "a"),
|
||||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"},
|
||||
),
|
||||
]
|
||||
for i in inputs:
|
||||
v = any_to_str_set(i.x)
|
||||
assert v == i.want
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -20,12 +20,12 @@ def test_config_class_is_singleton():
|
|||
def test_config_class_get_key_exception():
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
config = Config()
|
||||
config.get('wtf')
|
||||
config.get("wtf")
|
||||
assert str(exc_info.value) == "Key 'wtf' not found in environment variables or in the YAML file"
|
||||
|
||||
|
||||
def test_config_yaml_file_not_exists():
|
||||
config = Config('wtf.yaml')
|
||||
config = Config("wtf.yaml")
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
config.get('OPENAI_BASE_URL')
|
||||
config.get("OPENAI_BASE_URL")
|
||||
assert str(exc_info.value) == "Key 'OPENAI_BASE_URL' not found in environment variables or in the YAML file"
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ from metagpt.provider.openai_api import OpenAIGPTAPI
|
|||
|
||||
|
||||
async def try_hello(api):
|
||||
batch = [[{'role': 'user', 'content': 'hello'}]]
|
||||
batch = [[{"role": "user", "content": "hello"}]]
|
||||
results = await api.acompletion_batch_text(batch)
|
||||
return results
|
||||
|
||||
|
||||
async def aask_batch(api: OpenAIGPTAPI):
|
||||
results = await api.aask_batch(['hi', 'write python hello world.'])
|
||||
results = await api.aask_batch(["hi", "write python hello world."])
|
||||
logger.info(results)
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@File : test_custom_decoder.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.custom_decoder import CustomDecoder
|
||||
|
||||
|
|
@ -37,6 +38,46 @@ def test_parse_single_quote():
|
|||
parsed_data = decoder.decode(input_data)
|
||||
assert 'a"\n b' in parsed_data
|
||||
|
||||
input_data = """{
|
||||
'a': "
|
||||
b
|
||||
"
|
||||
}
|
||||
"""
|
||||
with pytest.raises(Exception):
|
||||
parsed_data = decoder.decode(input_data)
|
||||
|
||||
input_data = """{
|
||||
'a': '
|
||||
b
|
||||
'
|
||||
}
|
||||
"""
|
||||
with pytest.raises(Exception):
|
||||
parsed_data = decoder.decode(input_data)
|
||||
|
||||
|
||||
def test_parse_double_quote():
|
||||
decoder = CustomDecoder(strict=False)
|
||||
|
||||
input_data = """{
|
||||
"a": "
|
||||
b
|
||||
"
|
||||
}
|
||||
"""
|
||||
parsed_data = decoder.decode(input_data)
|
||||
assert parsed_data["a"] == "\n b\n"
|
||||
|
||||
input_data = """{
|
||||
"a": '
|
||||
b
|
||||
'
|
||||
}
|
||||
"""
|
||||
parsed_data = decoder.decode(input_data)
|
||||
assert parsed_data["a"] == "\n b\n"
|
||||
|
||||
|
||||
def test_parse_triple_double_quote():
|
||||
# Create a custom JSON decoder
|
||||
|
|
@ -54,6 +95,10 @@ def test_parse_triple_double_quote():
|
|||
parsed_data = decoder.decode(input_data)
|
||||
assert parsed_data["a"] == "b"
|
||||
|
||||
input_data = "{\"\"\"a\"\"\": '''b'''}"
|
||||
parsed_data = decoder.decode(input_data)
|
||||
assert parsed_data["a"] == "b"
|
||||
|
||||
|
||||
def test_parse_triple_single_quote():
|
||||
# Create a custom JSON decoder
|
||||
|
|
|
|||
64
tests/metagpt/utils/test_dependency_file.py
Normal file
64
tests/metagpt/utils/test_dependency_file.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/11/22
|
||||
@Author : mashenquan
|
||||
@File : test_dependency_file.py
|
||||
@Desc: Unit tests for dependency_file.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.utils.dependency_file import DependencyFile
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_file():
|
||||
class Input(BaseModel):
|
||||
x: Union[Path, str]
|
||||
deps: Optional[Set[Union[Path, str]]]
|
||||
key: Optional[Union[Path, str]]
|
||||
want: Set[str]
|
||||
|
||||
inputs = [
|
||||
Input(x="a/b.txt", deps={"c/e.txt", Path(__file__).parent / "d.txt"}, want={"c/e.txt", "d.txt"}),
|
||||
Input(
|
||||
x=Path(__file__).parent / "x/b.txt",
|
||||
deps={"s/e.txt", Path(__file__).parent / "d.txt"},
|
||||
key="x/b.txt",
|
||||
want={"s/e.txt", "d.txt"},
|
||||
),
|
||||
Input(x="f.txt", deps=None, want=set()),
|
||||
Input(x="a/b.txt", deps=None, want=set()),
|
||||
]
|
||||
|
||||
file = DependencyFile(workdir=Path(__file__).parent)
|
||||
|
||||
for i in inputs:
|
||||
await file.update(filename=i.x, dependencies=i.deps)
|
||||
assert await file.get(filename=i.key or i.x) == i.want
|
||||
|
||||
file2 = DependencyFile(workdir=Path(__file__).parent)
|
||||
file2.delete_file()
|
||||
assert not file.exists
|
||||
await file2.update(filename="a/b.txt", dependencies={"c/e.txt", Path(__file__).parent / "d.txt"}, persist=False)
|
||||
assert not file.exists
|
||||
await file2.save()
|
||||
assert file2.exists
|
||||
|
||||
file1 = DependencyFile(workdir=Path(__file__).parent)
|
||||
assert file1.exists
|
||||
assert await file1.get("a/b.txt") == set()
|
||||
await file1.load()
|
||||
assert await file1.get("a/b.txt") == {"c/e.txt", "d.txt"}
|
||||
file1.delete_file()
|
||||
assert not file.exists
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -15,12 +15,11 @@ from metagpt.utils.file import File
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("root_path", "filename", "content"),
|
||||
[(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")]
|
||||
[(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")],
|
||||
)
|
||||
async def test_write_and_read_file(root_path: Path, filename: str, content: bytes):
|
||||
full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode('utf-8'))
|
||||
full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode("utf-8"))
|
||||
assert isinstance(full_file_name, Path)
|
||||
assert root_path / filename == full_file_name
|
||||
file_data = await File.read(full_file_name)
|
||||
assert file_data.decode("utf-8") == content
|
||||
|
||||
|
|
|
|||
55
tests/metagpt/utils/test_file_repository.py
Normal file
55
tests/metagpt/utils/test_file_repository.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/11/20
|
||||
@Author : mashenquan
|
||||
@File : test_file_repository.py
|
||||
@Desc: Unit tests for file_repository.py
|
||||
"""
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.git_repository import ChangeType, GitRepository
|
||||
from tests.metagpt.utils.test_git_repository import mock_file
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_repo():
|
||||
local_path = Path(__file__).parent / "file_repo_git"
|
||||
if local_path.exists():
|
||||
shutil.rmtree(local_path)
|
||||
|
||||
git_repo = GitRepository(local_path=local_path, auto_init=True)
|
||||
assert not git_repo.changed_files
|
||||
|
||||
await mock_file(local_path / "g.txt", "")
|
||||
|
||||
file_repo_path = "file_repo1"
|
||||
full_path = local_path / file_repo_path
|
||||
assert not full_path.exists()
|
||||
file_repo = git_repo.new_file_repository(file_repo_path)
|
||||
assert file_repo.workdir == full_path
|
||||
assert file_repo.workdir.exists()
|
||||
await file_repo.save("a.txt", "AAA")
|
||||
await file_repo.save("b.txt", "BBB", ["a.txt"])
|
||||
doc = await file_repo.get("a.txt")
|
||||
assert "AAA" == doc.content
|
||||
doc = await file_repo.get("b.txt")
|
||||
assert "BBB" == doc.content
|
||||
assert {"a.txt"} == await file_repo.get_dependency("b.txt")
|
||||
assert {"a.txt": ChangeType.UNTRACTED, "b.txt": ChangeType.UNTRACTED} == file_repo.changed_files
|
||||
assert {"a.txt"} == await file_repo.get_changed_dependency("b.txt")
|
||||
await file_repo.save("d/e.txt", "EEE")
|
||||
assert ["d/e.txt"] == file_repo.get_change_dir_files("d")
|
||||
assert set(file_repo.all_files) == {"a.txt", "b.txt", "d/e.txt"}
|
||||
await file_repo.delete("d/e.txt")
|
||||
await file_repo.delete("d/e.txt") # delete twice
|
||||
assert set(file_repo.all_files) == {"a.txt", "b.txt"}
|
||||
|
||||
git_repo.delete_repository()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
103
tests/metagpt/utils/test_git_repository.py
Normal file
103
tests/metagpt/utils/test_git_repository.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/11/20
|
||||
@Author : mashenquan
|
||||
@File : test_git_repository.py
|
||||
@Desc: Unit tests for git_repository.py
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
async def mock_file(filename, content=""):
|
||||
async with aiofiles.open(str(filename), mode="w") as file:
|
||||
await file.write(content)
|
||||
|
||||
|
||||
async def mock_repo(local_path) -> (GitRepository, Path):
|
||||
if local_path.exists():
|
||||
shutil.rmtree(local_path)
|
||||
assert not local_path.exists()
|
||||
repo = GitRepository(local_path=local_path, auto_init=True)
|
||||
assert local_path.exists()
|
||||
assert local_path == repo.workdir
|
||||
assert not repo.changed_files
|
||||
|
||||
await mock_file(local_path / "a.txt")
|
||||
await mock_file(local_path / "b.txt")
|
||||
subdir = local_path / "subdir"
|
||||
subdir.mkdir(parents=True, exist_ok=True)
|
||||
await mock_file(subdir / "c.txt")
|
||||
return repo, subdir
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git():
|
||||
local_path = Path(__file__).parent / "git"
|
||||
repo, subdir = await mock_repo(local_path)
|
||||
|
||||
assert len(repo.changed_files) == 3
|
||||
repo.add_change(repo.changed_files)
|
||||
repo.commit("commit1")
|
||||
assert not repo.changed_files
|
||||
|
||||
await mock_file(local_path / "a.txt", "tests")
|
||||
await mock_file(subdir / "d.txt")
|
||||
rmfile = local_path / "b.txt"
|
||||
rmfile.unlink()
|
||||
assert repo.status
|
||||
|
||||
assert len(repo.changed_files) == 3
|
||||
repo.add_change(repo.changed_files)
|
||||
repo.commit("commit2")
|
||||
assert not repo.changed_files
|
||||
|
||||
assert repo.status
|
||||
|
||||
repo.delete_repository()
|
||||
assert not local_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git1():
|
||||
local_path = Path(__file__).parent / "git1"
|
||||
await mock_repo(local_path)
|
||||
|
||||
repo1 = GitRepository(local_path=local_path, auto_init=False)
|
||||
assert repo1.changed_files
|
||||
|
||||
file_repo = repo1.new_file_repository("__pycache__")
|
||||
await file_repo.save("a.pyc", content="")
|
||||
all_files = repo1.get_files(relative_path=".", filter_ignored=False)
|
||||
assert "__pycache__/a.pyc" in all_files
|
||||
all_files = repo1.get_files(relative_path=".", filter_ignored=True)
|
||||
assert "__pycache__/a.pyc" not in all_files
|
||||
|
||||
repo1.delete_repository()
|
||||
assert not local_path.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_file():
|
||||
local_path = Path(__file__).parent / "git2"
|
||||
repo, subdir = await mock_repo(local_path)
|
||||
|
||||
dependancy_file = await repo.get_dependency()
|
||||
assert not dependancy_file.exists
|
||||
|
||||
await dependancy_file.update(filename="a/b.txt", dependencies={"c/d.txt", "e/f.txt"})
|
||||
assert dependancy_file.exists
|
||||
|
||||
repo.delete_repository()
|
||||
assert not dependancy_file.exists
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -14,17 +14,17 @@ from metagpt.utils.common import OutputParser
|
|||
|
||||
def test_parse_blocks():
|
||||
test_text = "##block1\nThis is block 1.\n##block2\nThis is block 2."
|
||||
expected_result = {'block1': 'This is block 1.', 'block2': 'This is block 2.'}
|
||||
expected_result = {"block1": "This is block 1.", "block2": "This is block 2."}
|
||||
assert OutputParser.parse_blocks(test_text) == expected_result
|
||||
|
||||
|
||||
def test_parse_code():
|
||||
test_text = "```python\nprint('Hello, world!')```"
|
||||
expected_result = "print('Hello, world!')"
|
||||
assert OutputParser.parse_code(test_text, 'python') == expected_result
|
||||
assert OutputParser.parse_code(test_text, "python") == expected_result
|
||||
|
||||
with pytest.raises(Exception):
|
||||
OutputParser.parse_code(test_text, 'java')
|
||||
OutputParser.parse_code(test_text, "java")
|
||||
|
||||
|
||||
def test_parse_python_code():
|
||||
|
|
@ -45,13 +45,13 @@ def test_parse_python_code():
|
|||
|
||||
def test_parse_str():
|
||||
test_text = "name = 'Alice'"
|
||||
expected_result = 'Alice'
|
||||
expected_result = "Alice"
|
||||
assert OutputParser.parse_str(test_text) == expected_result
|
||||
|
||||
|
||||
def test_parse_file_list():
|
||||
test_text = "files=['file1', 'file2', 'file3']"
|
||||
expected_result = ['file1', 'file2', 'file3']
|
||||
expected_result = ["file1", "file2", "file3"]
|
||||
assert OutputParser.parse_file_list(test_text) == expected_result
|
||||
|
||||
with pytest.raises(Exception):
|
||||
|
|
@ -60,7 +60,7 @@ def test_parse_file_list():
|
|||
|
||||
def test_parse_data():
|
||||
test_data = "##block1\n```python\nprint('Hello, world!')\n```\n##block2\nfiles=['file1', 'file2', 'file3']"
|
||||
expected_result = {'block1': "print('Hello, world!')", 'block2': ['file1', 'file2', 'file3']}
|
||||
expected_result = {"block1": "print('Hello, world!')", "block2": ["file1", "file2", "file3"]}
|
||||
assert OutputParser.parse_data(test_data) == expected_result
|
||||
|
||||
|
||||
|
|
@ -103,9 +103,11 @@ def test_parse_data():
|
|||
None,
|
||||
Exception,
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception):
|
||||
def test_extract_struct(
|
||||
text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception
|
||||
):
|
||||
def case():
|
||||
resp = OutputParser.extract_struct(text, data_type)
|
||||
assert resp == parsed_data
|
||||
|
|
@ -117,7 +119,7 @@ def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], par
|
|||
case()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
t_text = '''
|
||||
## Required Python third-party packages
|
||||
```python
|
||||
|
|
@ -216,9 +218,9 @@ We need clarification on how the high score should be stored. Should it persist
|
|||
"Requirement Pool": (List[Tuple[str, str]], ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
t_text1 = '''## Original Requirements:
|
||||
t_text1 = """## Original Requirements:
|
||||
|
||||
The boss wants to create a web-based version of the game "Fly Bird".
|
||||
The user wants to create a web-based version of the game "Fly Bird".
|
||||
|
||||
## Product Goals:
|
||||
|
||||
|
|
@ -284,7 +286,7 @@ The product should be a web-based version of the game "Fly Bird" that is engagin
|
|||
## Anything UNCLEAR:
|
||||
|
||||
There are no unclear points.
|
||||
'''
|
||||
"""
|
||||
d = OutputParser.parse_data_with_mapping(t_text1, OUTPUT_MAPPING)
|
||||
import json
|
||||
|
||||
|
|
|
|||
|
|
@ -52,9 +52,11 @@ PAGE = """
|
|||
</html>
|
||||
"""
|
||||
|
||||
CONTENT = 'This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered '\
|
||||
'Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div '\
|
||||
'with a class "box".a link'
|
||||
CONTENT = (
|
||||
"This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered "
|
||||
"Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div "
|
||||
'with a class "box".a link'
|
||||
)
|
||||
|
||||
|
||||
def test_web_page():
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from metagpt.utils import pycst
|
||||
|
||||
code = '''
|
||||
code = """
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import overload
|
||||
|
|
@ -24,7 +24,7 @@ class Person:
|
|||
|
||||
def greet(self):
|
||||
return f"Hello, my name is {self.name} and I am {self.age} years old."
|
||||
'''
|
||||
"""
|
||||
|
||||
documented_code = '''
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@
|
|||
@File : test_read_docx.py
|
||||
"""
|
||||
|
||||
from metagpt.const import PROJECT_ROOT
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.utils.read_document import read_docx
|
||||
|
||||
|
||||
class TestReadDocx:
|
||||
def test_read_docx(self):
|
||||
docx_sample = PROJECT_ROOT / "tests/data/docx_for_test.docx"
|
||||
docx_sample = METAGPT_ROOT / "tests/data/docx_for_test.docx"
|
||||
docx = read_docx(docx_sample)
|
||||
assert len(docx) == 6
|
||||
|
|
|
|||
317
tests/metagpt/utils/test_repair_llm_raw_output.py
Normal file
317
tests/metagpt/utils/test_repair_llm_raw_output.py
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of repair_llm_raw_output
|
||||
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.repair_llm_raw_output import (
|
||||
RepairType,
|
||||
extract_content_from_output,
|
||||
repair_invalid_json,
|
||||
repair_llm_raw_output,
|
||||
retry_parse_json_text,
|
||||
)
|
||||
|
||||
CONFIG.repair_llm_output = True
|
||||
|
||||
|
||||
def test_repair_case_sensitivity():
|
||||
raw_output = """{
|
||||
"Original requirements": "Write a 2048 game",
|
||||
"search Information": "",
|
||||
"competitive Quadrant charT": "quadrantChart
|
||||
Campaign A: [0.3, 0.6]",
|
||||
"requirement analysis": "The 2048 game should be simple to play"
|
||||
}"""
|
||||
target_output = """{
|
||||
"Original Requirements": "Write a 2048 game",
|
||||
"Search Information": "",
|
||||
"Competitive Quadrant Chart": "quadrantChart
|
||||
Campaign A: [0.3, 0.6]",
|
||||
"Requirement Analysis": "The 2048 game should be simple to play"
|
||||
}"""
|
||||
req_keys = ["Original Requirements", "Search Information", "Competitive Quadrant Chart", "Requirement Analysis"]
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=req_keys)
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_repair_special_character_missing():
|
||||
raw_output = """[CONTENT]
|
||||
"Anything UNCLEAR": "No unclear requirements or information."
|
||||
[CONTENT]"""
|
||||
|
||||
target_output = """[CONTENT]
|
||||
"Anything UNCLEAR": "No unclear requirements or information."
|
||||
[/CONTENT]"""
|
||||
req_keys = ["[/CONTENT]"]
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=req_keys)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = """[CONTENT] tag
|
||||
[CONTENT]
|
||||
{
|
||||
"Anything UNCLEAR": "No unclear requirements or information."
|
||||
}
|
||||
[CONTENT]"""
|
||||
target_output = """[CONTENT] tag
|
||||
[CONTENT]
|
||||
{
|
||||
"Anything UNCLEAR": "No unclear requirements or information."
|
||||
}
|
||||
[/CONTENT]"""
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=req_keys)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = '[CONTENT] {"a": "b"} [CONTENT]'
|
||||
target_output = '[CONTENT] {"a": "b"} [/CONTENT]'
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=["[/CONTENT]"])
|
||||
print("output\n", output)
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_required_key_pair_missing():
|
||||
raw_output = '[CONTENT] {"a": "b"}'
|
||||
target_output = '[CONTENT] {"a": "b"}\n[/CONTENT]'
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=["[/CONTENT]"])
|
||||
assert output == target_output
|
||||
|
||||
raw_output = """[CONTENT]
|
||||
{
|
||||
"key": "value"
|
||||
]"""
|
||||
target_output = """[CONTENT]
|
||||
{
|
||||
"key": "value"
|
||||
]
|
||||
[/CONTENT]"""
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=["[/CONTENT]"])
|
||||
assert output == target_output
|
||||
|
||||
raw_output = """[CONTENT] tag
|
||||
[CONTENT]
|
||||
{
|
||||
"key": "value"
|
||||
}
|
||||
xxx
|
||||
"""
|
||||
target_output = """[CONTENT]
|
||||
{
|
||||
"key": "value"
|
||||
}
|
||||
[/CONTENT]"""
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=["[/CONTENT]"])
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_repair_json_format():
|
||||
raw_output = "{ xxx }]"
|
||||
target_output = "{ xxx }"
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = "[{ xxx }"
|
||||
target_output = "{ xxx }"
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = "{ xxx ]"
|
||||
target_output = "{ xxx }"
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_repair_invalid_json():
|
||||
raw_output = """{
|
||||
"key": "value"
|
||||
},
|
||||
}"""
|
||||
target_output = """{
|
||||
"key": "value"
|
||||
,
|
||||
}"""
|
||||
output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 3 column 1")
|
||||
assert output == target_output
|
||||
|
||||
raw_output = """{
|
||||
"key": "
|
||||
value
|
||||
},
|
||||
}"""
|
||||
target_output = """{
|
||||
"key": "
|
||||
value
|
||||
",
|
||||
}"""
|
||||
output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 4 column 1")
|
||||
output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
|
||||
assert output == target_output
|
||||
|
||||
raw_output = """{
|
||||
"key": '
|
||||
value
|
||||
},
|
||||
}"""
|
||||
target_output = """{
|
||||
"key": '
|
||||
value
|
||||
',
|
||||
}"""
|
||||
output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 4 column 1")
|
||||
output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
|
||||
output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_retry_parse_json_text():
|
||||
invalid_json_text = """{
|
||||
"Original Requirements": "Create a 2048 game",
|
||||
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis"
|
||||
],
|
||||
"Requirement Analysis": "The requirements are clear and well-defined"
|
||||
}"""
|
||||
target_json = {
|
||||
"Original Requirements": "Create a 2048 game",
|
||||
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis",
|
||||
"Requirement Analysis": "The requirements are clear and well-defined",
|
||||
}
|
||||
output = retry_parse_json_text(output=invalid_json_text)
|
||||
assert output == target_json
|
||||
|
||||
invalid_json_text = """{
|
||||
"Original Requirements": "Create a 2048 game",
|
||||
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis"
|
||||
},
|
||||
"Requirement Analysis": "The requirements are clear and well-defined"
|
||||
}"""
|
||||
target_json = {
|
||||
"Original Requirements": "Create a 2048 game",
|
||||
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis",
|
||||
"Requirement Analysis": "The requirements are clear and well-defined",
|
||||
}
|
||||
output = retry_parse_json_text(output=invalid_json_text)
|
||||
assert output == target_json
|
||||
|
||||
|
||||
def test_extract_content_from_output():
|
||||
"""
|
||||
cases
|
||||
xxx [CONTENT] xxxx [/CONTENT]
|
||||
xxx [CONTENT] xxx [CONTENT] xxxx [/CONTENT]
|
||||
xxx [CONTENT] xxxx [/CONTENT] xxx [CONTENT][/CONTENT] xxx [CONTENT][/CONTENT] # target pair is the last one
|
||||
"""
|
||||
|
||||
output = (
|
||||
'Sure! Here is the properly formatted JSON output based on the given context:\n\n[CONTENT]\n{\n"'
|
||||
'Required Python third-party packages": [\n"pygame==2.0.4",\n"pytest"\n],\n"Required Other language '
|
||||
'third-party packages": [\n"No third-party packages are required."\n],\n"Full API spec": "\nopenapi: '
|
||||
"3.0.0\n\ndescription: A JSON object representing the game state.\n\npaths:\n game:\n get:\n "
|
||||
"summary: Get the current game state.\n responses:\n 200:\n description: Game state."
|
||||
"\n\n moves:\n post:\n summary: Make a move.\n requestBody:\n description: Move to be "
|
||||
"made.\n content:\n applicationjson:\n schema:\n type: object\n "
|
||||
" properties:\n x:\n type: integer\n y:\n "
|
||||
" type: integer\n tile:\n type: object\n "
|
||||
"properties:\n value:\n type: integer\n x:\n "
|
||||
" type: integer\n y:\n type: integer\n\n "
|
||||
"undo-move:\n post:\n summary: Undo the last move.\n responses:\n 200:\n "
|
||||
" description: Undone move.\n\n end-game:\n post:\n summary: End the game.\n responses:\n "
|
||||
" 200:\n description: Game ended.\n\n start-game:\n post:\n summary: Start a new "
|
||||
"game.\n responses:\n 200:\n description: Game started.\n\n game-over:\n get:\n "
|
||||
" summary: Check if the game is over.\n responses:\n 200:\n description: Game "
|
||||
"over.\n 404:\n description: Game not over.\n\n score:\n get:\n summary: Get the "
|
||||
"current score.\n responses:\n 200:\n description: Score.\n\n tile:\n get:\n "
|
||||
"summary: Get a specific tile.\n parameters:\n tile_id:\n type: integer\n "
|
||||
"description: ID of the tile to get.\n responses:\n 200:\n description: Tile.\n\n "
|
||||
"tiles:\n get:\n summary: Get all tiles.\n responses:\n 200:\n description: "
|
||||
"Tiles.\n\n level:\n get:\n summary: Get the current level.\n responses:\n 200:\n "
|
||||
" description: Level.\n\n level-up:\n post:\n summary: Level up.\n responses:\n "
|
||||
"200:\n description: Level up successful.\n\n level-down:\n post:\n summary: Level "
|
||||
"down.\n responses:\n 200:\n description: Level down successful.\n\n restart:\n "
|
||||
"post:\n summary: Restart the game.\n responses:\n 200:\n description: Game "
|
||||
"restarted.\n\n help:\n get:\n summary: Get help.\n responses:\n 200:\n "
|
||||
"description: Help.\n\n version:\n get:\n summary: Get the version of the game.\n "
|
||||
'responses:\n 200:\n description: Version.\n\n}\n\n"Logic Analysis": [\n"game.py",'
|
||||
'\n"Contains the game logic."\n],\n"Task list": [\n"game.py",\n"Contains the game logic and should be '
|
||||
'done first."\n],\n"Shared Knowledge": "\n\'game.py\' contains the game logic.\n",\n"Anything '
|
||||
'UNCLEAR": "How to start the game."\n]\n\n[/CONTENT] Great! Your JSON output is properly formatted '
|
||||
"and correctly includes all the required sections. Here's a breakdown of what each section "
|
||||
"contains:\n\nRequired Python third-party packages:\n\n* pygame==2.0.4\n* pytest\n\nRequired Other "
|
||||
"language third-party packages:\n\n* No third-party packages are required.\n\nFull API spec:\n\n* "
|
||||
"openapi: 3.0.0\n* description: A JSON object representing the game state.\n* paths:\n + game: "
|
||||
"Get the current game state.\n + moves: Make a move.\n + undo-move: Undo the last move.\n + "
|
||||
"end-game: End the game.\n + start-game: Start a new game.\n + game-over: Check if the game is "
|
||||
"over.\n + score: Get the current score.\n + tile: Get a specific tile.\n + tiles: Get all tiles.\n "
|
||||
"+ level: Get the current level.\n + level-up: Level up.\n + level-down: Level down.\n + restart: "
|
||||
"Restart the game.\n + help: Get help.\n + version: Get the version of the game.\n\nLogic "
|
||||
"Analysis:\n\n* game.py contains the game logic.\n\nTask list:\n\n* game.py contains the game logic "
|
||||
"and should be done first.\n\nShared Knowledge:\n\n* 'game.py' contains the game logic.\n\nAnything "
|
||||
"UNCLEAR:\n\n* How to start the game.\n\nGreat job! This JSON output should provide a clear and "
|
||||
"comprehensive overview of the project's requirements and dependencies."
|
||||
)
|
||||
output = extract_content_from_output(output)
|
||||
assert output.startswith('{\n"Required Python third-party packages') and output.endswith(
|
||||
'UNCLEAR": "How to start the game."\n]'
|
||||
)
|
||||
|
||||
output = (
|
||||
"Sure, I would be happy to help! Here is the information you provided, formatted as a JSON object "
|
||||
'inside the [CONTENT] tag:\n\n[CONTENT]\n{\n"Original Requirements": "Create a 2048 game",\n"Search '
|
||||
'Information": "Search results for 2048 game",\n"Requirements": [\n"Create a game with the same rules '
|
||||
'as the original 2048 game",\n"Implement a user interface that is easy to use and understand",\n"Add a '
|
||||
'scoreboard to track the player progress",\n"Allow the player to undo and redo moves",\n"Implement a '
|
||||
'game over screen to display the final score"\n],\n"Product Goals": [\n"Create a fun and engaging game '
|
||||
'experience for the player",\n"Design a user interface that is visually appealing and easy to use",\n"'
|
||||
'Optimize the game for performance and responsiveness"\n],\n"User Stories": [\n"As a player, I want to '
|
||||
'be able to move tiles around the board to combine numbers",\n"As a player, I want to be able to undo '
|
||||
'and redo moves to correct mistakes",\n"As a player, I want to see the final score and game over screen'
|
||||
' when I win"\n],\n"Competitive Analysis": [\n"Competitor A: 2048 game with a simple user interface and'
|
||||
' basic graphics",\n"Competitor B: 2048 game with a more complex user interface and better graphics",'
|
||||
'\n"Competitor C: 2048 game with a unique twist on the rules and a more challenging gameplay experience"'
|
||||
'\n],\n"Competitive Quadrant Chart": "quadrantChart\\n\ttitle Reach and engagement of campaigns\\n\t\t'
|
||||
"x-axis Low Reach --> High Reach\\n\t\ty-axis Low Engagement --> High Engagement\\n\tquadrant-1 We "
|
||||
"should expand\\n\tquadrant-2 Need to promote\\n\tquadrant-3 Re-evaluate\\n\tquadrant-4 May be "
|
||||
"improved\\n\tCampaign A: [0.3, 0.6]\\n\tCampaign B: [0.45, 0.23]\\n\tCampaign C: [0.57, 0.69]\\n\t"
|
||||
'Campaign D: [0.78, 0.34]\\n\tCampaign E: [0.40, 0.34]\\n\tCampaign F: [0.35, 0.78]"\n],\n"Requirement '
|
||||
'Analysis": "The requirements are clear and well-defined, but there may be some ambiguity around the '
|
||||
'specific implementation details",\n"Requirement Pool": [\n["P0", "Implement a game with the same '
|
||||
'rules as the original 2048 game"],\n["P1", "Add a scoreboard to track the player progress"],\n["P2", '
|
||||
'"Allow the player to undo and redo moves"]\n],\n"UI Design draft": "The UI should be simple and easy '
|
||||
"to use, with a clean and visually appealing design. The game board should be the main focus of the "
|
||||
'UI, with clear and concise buttons for the player to interact with.",\n"Anything UNCLEAR": ""\n}\n'
|
||||
"[/CONTENT]\n\nI hope this helps! Let me know if you have any further questions or if there anything "
|
||||
"else I can do to assist you."
|
||||
)
|
||||
output = extract_content_from_output(output)
|
||||
assert output.startswith('{\n"Original Requirements"') and output.endswith('"Anything UNCLEAR": ""\n}')
|
||||
|
||||
output = """ Sure, I'd be happy to help! Here's the JSON output for the given context:\n\n[CONTENT]\n{
|
||||
"Implementation approach": "We will use the open-source framework PyGame to create a 2D game engine, which will
|
||||
provide us with a robust and efficient way to handle game logic and rendering. PyGame is widely used in the game
|
||||
development community and has a large number of resources and tutorials available online.",\n"Python package name":
|
||||
"pygame_2048",\n"File list": ["main.py", "game.py", "constants.py", "ui.py"],\n"Data structures and interface
|
||||
definitions": '\nclassDiagram\n class Game{\n +int score\n +list<tile> tiles\n +function
|
||||
move_tile(tile, int dx, int dy)\n +function undo_move()\n +function get_highest_score()\n }\n
|
||||
class Tile{\n +int value\n +int x\n +int y\n }\n ...\n Game "1" -- "1" Food: has\n',
|
||||
\n"Program call flow": '\nsequenceDiagram\n participant M as Main\n participant G as Game\n ...\n G->>M:
|
||||
end game\n',\n"Anything UNCLEAR": "The requirement is clear to me."\n}\n[/CONTENT] Here's the JSON output for the
|
||||
given context, wrapped inside the [CONTENT][/CONTENT] format:\n\n[CONTENT]\n{\n"Implementation approach": "We will
|
||||
use the open-source framework PyGame to create a 2D game engine, which will provide us with a robust and efficient
|
||||
way to handle game logic and rendering. PyGame is widely used in the game development community and has a large
|
||||
number of resources and tutorials available online.",\n"Python package name": "pygame_2048",\n"File list":
|
||||
["main.py", "game.py", "constants.py", "ui.py"],\n"Data structures and interface definitions": '\nclassDiagram\n
|
||||
class Game{\n +int score\n +list<tile> tiles\n +function move_tile(tile, int dx, int dy)\n
|
||||
+function undo_move()\n +function get_highest_score()\n }\n class Tile{\n +int value\n +int x\n
|
||||
+int y\n }\n ...\n Game "1" -- "1" Food: has\n',\n"Program call flow": '\nsequenceDiagram\n participant
|
||||
M as Main\n participant G as Game\n ...\n G->>M: end game\n',\n"Anything UNCLEAR": "The requirement is
|
||||
clear to me."\n}\n[/CONTENT] Great! Your JSON output is well-formatted and provides all the necessary
|
||||
information for a developer to understand the design and implementation of the 2048 game.
|
||||
"""
|
||||
output = extract_content_from_output(output)
|
||||
assert output.startswith('{\n"Implementation approach"') and output.endswith(
|
||||
'"Anything UNCLEAR": "The requirement is clear to me."\n}'
|
||||
)
|
||||
|
|
@ -1,11 +1,13 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of serialize
|
||||
"""
|
||||
@Desc : the unittest of serialize
|
||||
"""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.serialize import (
|
||||
actionoutout_schema_to_mapping,
|
||||
|
|
@ -52,7 +54,7 @@ def test_actionoutout_schema_to_mapping():
|
|||
def test_serialize_and_deserialize_message():
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionOutput.create_model_class("prd", out_mapping)
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
message = Message(
|
||||
content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def _paragraphs(n):
|
|||
(_msgs(), "gpt-4", "Hello," * 1000, 2000, 2),
|
||||
(_msgs(), "gpt-4-32k", "System", 4000, 14),
|
||||
(_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
|
||||
assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
|
||||
|
|
@ -42,7 +42,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected
|
|||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
|
||||
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
|
||||
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
|
||||
ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
|
||||
|
|
@ -58,7 +58,7 @@ def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, r
|
|||
("......", ".", 2, ["...", "..."]),
|
||||
("......", ".", 3, ["..", "..", ".."]),
|
||||
(".......", ".", 2, ["....", "..."]),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_split_paragraph(paragraph, sep, count, expected):
|
||||
ret = split_paragraph(paragraph, sep, count)
|
||||
|
|
@ -71,7 +71,7 @@ def test_split_paragraph(paragraph, sep, count, expected):
|
|||
("Hello\\nWorld", "Hello\nWorld"),
|
||||
("Hello\\tWorld", "Hello\tWorld"),
|
||||
("Hello\\u0020World", "Hello World"),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_decode_unicode_escape(text, expected):
|
||||
assert decode_unicode_escape(text) == expected
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue