Merge main branch

Commit 24e617b362 by mannaandpoem on 2024-01-03 19:48:46 +08:00
325 changed files with 11290 additions and 3760 deletions

View file

@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/24 20:32
@Author : alexanderwu
@File : mock_json.py
"""
PRD = {
"Language": "zh_cn",
"Programming Language": "Python",
"Original Requirements": "写一个简单的cli贪吃蛇",
"Project Name": "cli_snake",
"Product Goals": ["创建一个简单易用的贪吃蛇游戏", "提供良好的用户体验", "支持不同难度级别"],
"User Stories": [
"作为玩家,我希望能够选择不同的难度级别",
"作为玩家,我希望在每局游戏结束后能够看到我的得分",
"作为玩家,我希望在输掉游戏后能够重新开始",
"作为玩家,我希望看到简洁美观的界面",
"作为玩家,我希望能够在手机上玩游戏",
],
"Competitive Analysis": ["贪吃蛇游戏A界面简单缺乏响应式特性", "贪吃蛇游戏B美观且响应式的界面显示最高得分", "贪吃蛇游戏C响应式界面显示最高得分但有很多广告"],
"Competitive Quadrant Chart": 'quadrantChart\n title "Reach and engagement of campaigns"\n x-axis "Low Reach" --> "High Reach"\n y-axis "Low Engagement" --> "High Engagement"\n quadrant-1 "We should expand"\n quadrant-2 "Need to promote"\n quadrant-3 "Re-evaluate"\n quadrant-4 "May be improved"\n "Game A": [0.3, 0.6]\n "Game B": [0.45, 0.23]\n "Game C": [0.57, 0.69]\n "Game D": [0.78, 0.34]\n "Game E": [0.40, 0.34]\n "Game F": [0.35, 0.78]\n "Our Target Product": [0.5, 0.6]',
"Requirement Analysis": "",
"Requirement Pool": [["P0", "主要代码..."], ["P0", "游戏算法..."]],
"UI Design draft": "基本功能描述,简单的风格和布局。",
"Anything UNCLEAR": "",
}
DESIGN = {
"Implementation approach": "我们将使用Python编程语言并选择合适的开源框架来实现贪吃蛇游戏。我们将分析需求中的难点并选择合适的开源框架来简化开发流程。",
"File list": ["main.py", "game.py"],
"Data structures and interfaces": "\nclassDiagram\n class Game {\n -int width\n -int height\n -int score\n -int speed\n -List<Point> snake\n -Point food\n +__init__(width: int, height: int, speed: int)\n +start_game()\n +change_direction(direction: str)\n +game_over()\n +update_snake()\n +update_food()\n +check_collision()\n }\n class Point {\n -int x\n -int y\n +__init__(x: int, y: int)\n }\n Game --> Point\n",
"Program call flow": "\nsequenceDiagram\n participant M as Main\n participant G as Game\n M->>G: start_game()\n M->>G: change_direction(direction)\n G->>G: update_snake()\n G->>G: update_food()\n G->>G: check_collision()\n G-->>G: game_over()\n",
"Anything UNCLEAR": "",
}
TASKS = {
"Required Python packages": ["pygame==2.0.1"],
"Required Other language third-party packages": ["No third-party dependencies required"],
"Logic Analysis": [
["game.py", "Contains Game class and related functions for game logic"],
["main.py", "Contains the main function, imports Game class from game.py"],
],
"Task list": ["game.py", "main.py"],
"Full API spec": "",
"Shared Knowledge": "'game.py' contains functions shared across the project.",
"Anything UNCLEAR": "",
}
FILE_GAME = """## game.py
import pygame
import random
class Point:
def __init__(self, x: int, y: int):
self.x = x
self.y = y
class Game:
def __init__(self, width: int, height: int, speed: int):
self.width = width
self.height = height
self.score = 0
self.speed = speed
self.snake = [Point(width // 2, height // 2)]
self.food = self._create_food()
def start_game(self):
pygame.init()
self._display = pygame.display.set_mode((self.width, self.height))
pygame.display.set_caption('Snake Game')
self._clock = pygame.time.Clock()
self._running = True
while self._running:
self._handle_events()
self._update_snake()
self._update_food()
self._check_collision()
self._draw_screen()
self._clock.tick(self.speed)
def change_direction(self, direction: str):
# Update the direction of the snake based on user input
pass
def game_over(self):
# Display game over message and handle game over logic
pass
def _create_food(self) -> Point:
# Create and return a new food Point
return Point(random.randint(0, self.width - 1), random.randint(0, self.height - 1))
def _handle_events(self):
for event in pygame.event.get():
if event.type == pygame.QUIT:
self._running = False
def _update_snake(self):
# Update the position of the snake based on its direction
pass
def _update_food(self):
# Update the position of the food if the snake eats it
pass
def _check_collision(self):
# Check for collision between the snake and the walls or itself
pass
def _draw_screen(self):
self._display.fill((0, 0, 0)) # Clear the screen
# Draw the snake and food on the screen
pygame.display.update()
if __name__ == "__main__":
game = Game(800, 600, 15)
game.start_game()
"""
FILE_GAME_CR_1 = """## Code Review: game.py
1. Yes, the code is implemented as per the requirements. It initializes the game with the specified width, height, and speed, and starts the game loop.
2. No, the logic for handling events and updating the snake, food, and collision is not implemented. To correct this, we need to implement the logic for handling events, updating the snake and food positions, and checking for collisions.
3. Yes, the existing code follows the "Data structures and interfaces" by defining the Game and Point classes with the specified attributes and methods.
4. No, several functions such as change_direction, game_over, _update_snake, _update_food, and _check_collision are not implemented. These functions need to be implemented to complete the game logic.
5. Yes, all necessary pre-dependencies have been imported. The required pygame package is imported at the beginning of the file.
6. No, methods from other files are not being reused as there are no other files being imported or referenced in the current code.
## Actions
1. Implement the logic for handling events, updating the snake and food positions, and checking for collisions within the Game class.
2. Implement the change_direction and game_over methods to handle user input and game over logic.
3. Implement the _update_snake method to update the position of the snake based on its direction.
4. Implement the _update_food method to update the position of the food if the snake eats it.
5. Implement the _check_collision method to check for collision between the snake and the walls or itself.
## Code Review Result
LBTM"""

View file

@ -3,7 +3,7 @@
"""
@Time : 2023/5/18 23:51
@Author : alexanderwu
@File : mock.py
@File : mock_markdown.py
"""
PRD_SAMPLE = """## Original Requirements

View file

@ -5,9 +5,37 @@
@Author : alexanderwu
@File : test_action.py
"""
from metagpt.actions import Action, WritePRD, WriteTest
import pytest
from metagpt.actions import Action, ActionType, WritePRD, WriteTest
def test_action_repr():
actions = [Action(), WriteTest(), WritePRD()]
assert "WriteTest" in str(actions)
def test_action_type():
assert ActionType.WRITE_PRD.value == WritePRD
assert ActionType.WRITE_TEST.value == WriteTest
assert ActionType.WRITE_PRD.name == "WRITE_PRD"
assert ActionType.WRITE_TEST.name == "WRITE_TEST"
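These assertions only hold if ActionType is an Enum whose member values are the Action classes themselves. A minimal self-contained sketch of that pattern, with stub classes standing in for the real imports:

```python
from enum import Enum

class WritePRD: ...
class WriteTest: ...

class ActionType(Enum):
    # Each member's value is the Action class itself, so `.value` yields the
    # class and `.name` the registry key.
    WRITE_PRD = WritePRD
    WRITE_TEST = WriteTest

assert ActionType.WRITE_PRD.value is WritePRD
assert ActionType.WRITE_TEST.name == "WRITE_TEST"
```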
def test_simple_action():
action = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
assert action.name == "AlexSay"
assert action.node.instruction == "Express your opinion with emotion and don't repeat it"
def test_empty_action():
action = Action()
assert action.name == "Action"
assert not action.node
@pytest.mark.asyncio
async def test_empty_action_exception():
action = Action()
with pytest.raises(NotImplementedError):
await action.run()

View file

@ -5,11 +5,15 @@
@Author : alexanderwu
@File : test_action_node.py
"""
from typing import List, Tuple
import pytest
from pydantic import ValidationError
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.environment import Environment
from metagpt.llm import LLM
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.team import Team
@ -17,35 +21,35 @@ from metagpt.team import Team
@pytest.mark.asyncio
async def test_debate_two_roles():
action1 = Action(name="BidenSay", instruction="Express opinions and argue vigorously, and strive to gain votes")
action2 = Action(name="TrumpSay", instruction="Express opinions and argue vigorously, and strive to gain votes")
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
biden = Role(
name="Biden", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
)
trump = Role(
name="Trump", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
)
env = Environment(desc="US election live broadcast")
team = Team(investment=10.0, env=env, roles=[biden, trump])
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
assert "BidenSay" in history
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
assert "Alex" in history
@pytest.mark.asyncio
async def test_debate_one_role_in_env():
action = Action(name="Debate", instruction="Express opinions and argue vigorously, and strive to gain votes")
biden = Role(name="Biden", profile="Democratic candidate", goal="Win the election", actions=[action])
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
env = Environment(desc="US election live broadcast")
team = Team(investment=10.0, env=env, roles=[biden])
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
assert "Debate" in history
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
assert "Alex" in history
@pytest.mark.asyncio
async def test_debate_one_role():
action = Action(name="Debate", instruction="Express opinions and argue vigorously, and strive to gain votes")
biden = Role(name="Biden", profile="Democratic candidate", goal="Win the election", actions=[action])
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
msg: Message = await biden.run("Topic: climate change. Under 80 words per message.")
assert len(msg.content) > 10
@ -74,15 +78,94 @@ async def test_action_node_one_layer():
assert "key-a" in markdown_template
assert node_dict["key-a"] == "instruction-b"
assert "key-a" in repr(node)
@pytest.mark.asyncio
async def test_action_node_two_layer():
node_a = ActionNode(key="key-a", expected_type=str, instruction="i-a", example="e-a")
node_b = ActionNode(key="key-b", expected_type=str, instruction="i-b", example="e-b")
node_a = ActionNode(key="reasoning", expected_type=str, instruction="reasoning step by step", example="")
node_b = ActionNode(key="answer", expected_type=str, instruction="the final answer", example="")
root = ActionNode.from_children(key="", nodes=[node_a, node_b])
assert "key-a" in root.children
root = ActionNode.from_children(key="detail answer", nodes=[node_a, node_b])
assert "reasoning" in root.children
assert node_b in root.children.values()
json_template = root.compile(context="123", schema="json", mode="auto")
assert "i-a" in json_template
# FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST.
answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
assert "579" in answer1.content
answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
assert "579" in answer2.content
t_dict = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
"Logic Analysis": [
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
["static/js/script.js", "Handles user interactions and updates the game UI."],
["static/css/styles.css", "Defines the styles for the game UI."],
["templates/index.html", "The main page of the web application. Displays the game UI."],
],
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
}
t_dict_min = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
}
WRITE_TASKS_OUTPUT_MAPPING = {
"Required Python third-party packages": (str, ...),
"Required Other language third-party packages": (str, ...),
"Full API spec": (str, ...),
"Logic Analysis": (List[Tuple[str, str]], ...),
"Task list": (List[str], ...),
"Shared Knowledge": (str, ...),
"Anything UNCLEAR": (str, ...),
}
WRITE_TASKS_OUTPUT_MAPPING_MISSING = {
"Required Python third-party packages": (str, ...),
}
def test_create_model_class():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
assert test_class.__name__ == "test_class"
output = test_class(**t_dict)
print(output.schema())
assert output.schema()["title"] == "test_class"
assert output.schema()["type"] == "object"
assert output.schema()["properties"]["Full API spec"]
def test_create_model_class_with_fields_unrecognized():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING)
assert test_class.__name__ == "test_class"
_ = test_class(**t_dict) # just warning
def test_create_model_class_with_fields_missing():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
assert test_class.__name__ == "test_class"
with pytest.raises(ValidationError):
_ = test_class(**t_dict_min)
def test_create_model_class_with_mapping():
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
t1 = t(**t_dict)
value = t1.model_dump()["Task list"]
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()
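ActionNode.create_model_class builds a pydantic model class at runtime from a `{field: (type, default)}` mapping, where `...` marks a required field. A minimal sketch of the underlying idea using pydantic's own create_model (identifier-style field names are used here for simplicity; the real mapping above uses keys with spaces, which the helper evidently supports):

```python
from typing import List, Tuple

from pydantic import ValidationError, create_model

# Each entry maps a field name to (type, default); `...` means required.
mapping = {
    "task_list": (List[str], ...),
    "logic_analysis": (List[Tuple[str, str]], ...),
}
TaskModel = create_model("TaskModel", **mapping)

m = TaskModel(task_list=["game.py", "main.py"], logic_analysis=[("game.py", "core logic")])
assert m.model_dump()["task_list"] == ["game.py", "main.py"]

try:
    TaskModel(task_list=["game.py"])  # missing required field
except ValidationError as exc:
    print(exc)  # mirrors test_create_model_class_with_fields_missing
```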

View file

@ -1,53 +0,0 @@
#!/usr/bin/env python
# coding: utf-8
"""
@Time : 2023/7/11 10:49
@Author : chengmaoyu
@File : test_action_output
"""
from typing import List, Tuple
from metagpt.actions.action_node import ActionNode
t_dict = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
"Logic Analysis": [
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
["static/js/script.js", "Handles user interactions and updates the game UI."],
["static/css/styles.css", "Defines the styles for the game UI."],
["templates/index.html", "The main page of the web application. Displays the game UI."],
],
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
}
WRITE_TASKS_OUTPUT_MAPPING = {
"Required Python third-party packages": (str, ...),
"Required Other language third-party packages": (str, ...),
"Full API spec": (str, ...),
"Logic Analysis": (List[Tuple[str, str]], ...),
"Task list": (List[str], ...),
"Shared Knowledge": (str, ...),
"Anything UNCLEAR": (str, ...),
}
def test_create_model_class():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
assert test_class.__name__ == "test_class"
def test_create_model_class_with_mapping():
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
t1 = t(**t_dict)
value = t1.dict()["Task list"]
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()

View file

@ -1,16 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/1 22:50
@Author : alexanderwu
@File : test_azure_tts.py
"""
from metagpt.tools.azure_tts import AzureTTS
def test_azure_tts():
azure_tts = AzureTTS()
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav")
    # Running this requires SUBSCRIPTION_KEY to be configured first
    # TODO: Verifying the result would additionally require a matching ASR step to confirm that the synthesized speech round-trips to roughly the same text, but that does not exist yet

View file

@ -1,57 +0,0 @@
import pytest
from metagpt.actions.clone_function import CloneFunction, run_function_code
source_code = """
import pandas as pd
import ta
def user_indicator():
    # Read the stock data
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
    # Compute the simple moving average
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
    # Compute the Bollinger Bands
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
"""
template_code = """
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD']) -> pd.DataFrame:
import pandas as pd
# here is your code.
"""
def get_expected_res():
import pandas as pd
import ta
    # Read the stock data
stock_data = pd.read_csv("./tests/data/baba_stock.csv")
stock_data.head()
    # Compute the simple moving average
stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6)
stock_data[["Date", "Close", "SMA"]].head()
    # Compute the Bollinger Bands
stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = (
ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20),
ta.volatility.bollinger_mavg(stock_data["Close"], window=20),
ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20),
)
stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head()
return stock_data
@pytest.mark.asyncio
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert "def " in code
stock_path = "./tests/data/baba_stock.csv"
df, msg = run_function_code(code, "stock_indicator", stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)

View file

@ -117,6 +117,7 @@ if __name__ == '__main__':
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_debug_error():
CONFIG.src_workspace = CONFIG.git_repo.workdir / uuid.uuid4().hex
ctx = RunCodeContext(
@ -142,7 +143,7 @@ async def test_debug_error():
"Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n",
)
await FileRepository.save_file(
filename=ctx.output_filename, content=output_data.json(), relative_path=TEST_OUTPUTS_FILE_REPO
filename=ctx.output_filename, content=output_data.model_dump_json(), relative_path=TEST_OUTPUTS_FILE_REPO
)
debug_error = DebugError(context=ctx)
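The `output_data.json()` to `output_data.model_dump_json()` change here recurs throughout this commit: pydantic v2 renamed v1's serialization methods. A minimal illustration of the rename on an arbitrary model:

```python
from pydantic import BaseModel

class RunOutput(BaseModel):
    summary: str
    success: bool

out = RunOutput(summary="Ran 5 tests in 0.007s", success=False)
# pydantic v1 spellings (deprecated in v2): out.json(), out.dict()
print(out.model_dump_json())  # v2 spelling, as used by this commit
print(out.model_dump())
```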

View file

@ -13,18 +13,19 @@ from metagpt.const import PRDS_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.file_repository import FileRepository
from tests.metagpt.actions.mock import PRD_SAMPLE
from tests.metagpt.actions.mock_markdown import PRD_SAMPLE
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api():
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", PRD_SAMPLE]
for prd in inputs:
await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
design_api = WriteDesign("design_api")
design_api = WriteDesign()
result = await design_api.run([Message(content=prd, instruct_content=None)])
result = await design_api.run(Message(content=prd, instruct_content=None))
logger.info(result)
assert result

View file

@ -11,6 +11,7 @@ from metagpt.actions.design_api_review import DesignReview
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api_review():
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
api_design = """
@ -26,7 +27,7 @@ API list:
"""
_ = "API设计看起来非常合理满足了PRD中的所有需求。"
design_api_review = DesignReview("design_api_review")
design_api_review = DesignReview()
result = await design_api_review.run(prd, api_design)

View file

@ -0,0 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/25 22:38
@Author : alexanderwu
@File : test_fix_bug.py
"""
import pytest
from metagpt.actions.fix_bug import FixBug
@pytest.mark.asyncio
async def test_fix_bug():
fix_bug = FixBug()
assert fix_bug.name == "FixBug"

View file

@ -20,6 +20,7 @@ context = """
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_generate_questions():
action = GenerateQuestions()
rsp = await action.run(context)

View file

@ -7,26 +7,25 @@
@File : test_invoice_ocr.py
"""
import os
from pathlib import Path
import pytest
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
from metagpt.const import TEST_DATA_PATH
@pytest.mark.asyncio
@pytest.mark.parametrize(
"invoice_path",
[
"../../data/invoices/invoice-3.jpg",
"../../data/invoices/invoice-4.zip",
Path("invoices/invoice-3.jpg"),
Path("invoices/invoice-4.zip"),
],
)
async def test_invoice_ocr(invoice_path: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
async def test_invoice_ocr(invoice_path: Path):
invoice_path = TEST_DATA_PATH / invoice_path
resp = await InvoiceOCR().run(file_path=Path(invoice_path))
assert isinstance(resp, list)
@ -34,25 +33,30 @@ async def test_invoice_ocr(invoice_path: str):
@pytest.mark.parametrize(
("invoice_path", "expected_result"),
[
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
(Path("invoices/invoice-1.pdf"), {"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}),
],
)
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
async def test_generate_table(invoice_path: Path, expected_result: dict):
invoice_path = TEST_DATA_PATH / invoice_path
filename = invoice_path.name
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
assert table_data == expected_result
assert isinstance(table_data, list)
table_data = table_data[0]
assert expected_result["收款人"] == table_data["收款人"]
assert expected_result["城市"] in table_data["城市"]
assert float(expected_result["总费用/元"]) == float(table_data["总费用/元"])
assert expected_result["开票日期"] == table_data["开票日期"]
@pytest.mark.asyncio
@pytest.mark.parametrize(
("invoice_path", "query", "expected_result"),
[("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
[(Path("invoices/invoice-1.pdf"), "Invoicing date", "2023年02月03日")],
)
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
@pytest.mark.usefixtures("llm_mock")
async def test_reply_question(invoice_path: Path, query: dict, expected_result: str):
invoice_path = TEST_DATA_PATH / invoice_path
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
assert expected_result in result

View file

@ -12,6 +12,7 @@ from metagpt.logs import logger
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_prepare_interview():
action = PrepareInterview()
rsp = await action.run("I just graduated and hope to find a job as a Python engineer")

View file

@ -6,10 +6,27 @@
@File : test_project_management.py
"""
import pytest
class TestCreateProjectPlan:
pass
from metagpt.actions.project_management import WriteTasks
from metagpt.config import CONFIG
from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.file_repository import FileRepository
from tests.metagpt.actions.mock_json import DESIGN, PRD
class TestAssignTasks:
pass
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api():
await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
logger.info(CONFIG.git_repo)
action = WriteTasks()
result = await action.run(Message(content="", instruct_content=None))
logger.info(result)
assert result

View file

@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/20
@Author : mashenquan
@File : test_rebuild_class_view.py
@Desc : Unit tests for rebuild_class_view.py
"""
from pathlib import Path
import pytest
from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.llm import LLM
@pytest.mark.asyncio
async def test_rebuild():
action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM())
await action.run()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -1,6 +1,21 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/28
@Author : mashenquan
@File : test_research.py
"""
import pytest
from metagpt.actions import research
from metagpt.actions import CollectLinks, research
@pytest.mark.asyncio
async def test_action():
action = CollectLinks()
result = await action.run(topic="baidu")
assert result
@pytest.mark.asyncio
@ -17,7 +32,7 @@ async def test_collect_links(mocker):
elif "sort the remaining search results" in prompt:
return "[1,2]"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
resp = await research.CollectLinks().run("The application of MetaGPT")
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
assert i in resp
@ -36,7 +51,7 @@ async def test_collect_links_with_rank_func(mocker):
rank_after.append(results)
return results
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_collect_links_llm_ask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
for x, y, z in zip(rank_before, rank_after, resp.values()):
assert x[::-1] == y
@ -48,7 +63,7 @@ async def test_web_browse_and_summarize(mocker):
async def mock_llm_ask(*args, **kwargs):
return "metagpt"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
url = "https://github.com/geekan/MetaGPT"
url2 = "https://github.com/trending"
query = "What's new in metagpt"
@ -64,7 +79,7 @@ async def test_web_browse_and_summarize(mocker):
async def mock_llm_ask(*args, **kwargs):
return "Not relevant."
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
resp = await research.WebBrowseAndSummarize().run(url, query=query)
assert len(resp) == 1
@ -81,7 +96,7 @@ async def test_conduct_research(mocker):
data = f"# Research Report\n## Introduction\n{args} {kwargs}"
return data
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
content = (
"MetaGPT takes a one line requirement as input and "
"outputs user stories / competitive analysis / requirements / data structures / APIs / documents, etc."
@ -103,3 +118,7 @@ async def mock_collect_links_llm_ask(self, prompt: str, system_msgs):
return "[1,2]"
return ""
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -14,13 +14,13 @@ from metagpt.schema import RunCodeContext
@pytest.mark.asyncio
async def test_run_text():
result, errs = await RunCode.run_text("result = 1 + 1")
assert result == 2
assert errs == ""
out, err = await RunCode.run_text("result = 1 + 1")
assert out == 2
assert err == ""
result, errs = await RunCode.run_text("result = 1 / 0")
assert result == ""
assert "ZeroDivisionError" in errs
out, err = await RunCode.run_text("result = 1 / 0")
assert out == ""
assert "division by zero" in err
@pytest.mark.asyncio

View file

@ -0,0 +1,87 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/19
@Author : mashenquan
@File : test_skill_action.py
@Desc : Unit tests.
"""
import pytest
from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
from metagpt.learn.skill_loader import Example, Parameter, Returns, Skill
class TestSkillAction:
skill = Skill(
name="text_to_image",
description="Create a drawing based on the text.",
id="text_to_image.text_to_image",
x_prerequisite={
"configurations": {
"OPENAI_API_KEY": {
"type": "string",
"description": "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`",
},
"METAGPT_TEXT_TO_IMAGE_MODEL_URL": {"type": "string", "description": "Model url."},
},
"required": {"oneOf": ["OPENAI_API_KEY", "METAGPT_TEXT_TO_IMAGE_MODEL_URL"]},
},
parameters={
"text": Parameter(type="string", description="The text used for image conversion."),
"size_type": Parameter(type="string", description="size type"),
},
examples=[
Example(ask="Draw a girl", answer='text_to_image(text="Draw a girl", size_type="512x512")'),
Example(ask="Draw an apple", answer='text_to_image(text="Draw an apple", size_type="512x512")'),
],
returns=Returns(type="string", format="base64"),
)
@pytest.mark.asyncio
async def test_parser(self):
args = ArgumentsParingAction.parse_arguments(
skill_name="text_to_image", txt='`text_to_image(text="Draw an apple", size_type="512x512")`'
)
assert args.get("text") == "Draw an apple"
assert args.get("size_type") == "512x512"
@pytest.mark.asyncio
async def test_parser_action(self):
parser_action = ArgumentsParingAction(skill=self.skill, ask="Draw an apple")
rsp = await parser_action.run()
assert rsp
assert parser_action.args
assert parser_action.args.get("text") == "Draw an apple"
assert parser_action.args.get("size_type") == "512x512"
action = SkillAction(skill=self.skill, args=parser_action.args)
rsp = await action.run()
assert rsp
assert "image/png;base64," in rsp.content or "http" in rsp.content
@pytest.mark.parametrize(
("skill_name", "txt", "want"),
[
("skill1", 'skill1(a="1", b="2")', {"a": "1", "b": "2"}),
("skill1", '(a="1", b="2")', None),
("skill1", 'skill1(a="1", b="2"', None),
],
)
def test_parse_arguments(self, skill_name, txt, want):
args = ArgumentsParingAction.parse_arguments(skill_name, txt)
assert args == want
@pytest.mark.asyncio
async def test_find_and_call_function_error(self):
with pytest.raises(ValueError):
await SkillAction.find_and_call_function("dummy_call", {"a": 1})
@pytest.mark.asyncio
async def test_skill_action_error(self):
action = SkillAction(skill=self.skill, args={})
await action.run()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -177,6 +177,7 @@ class Snake:
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_summarize_code():
CONFIG.src_workspace = CONFIG.git_repo.workdir / "src"
await FileRepository.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)

View file

@ -0,0 +1,52 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/28
@Author : mashenquan
@File : test_talk_action.py
"""
import pytest
from metagpt.actions.talk_action import TalkAction
from metagpt.config import CONFIG
from metagpt.schema import Message
@pytest.mark.asyncio
@pytest.mark.parametrize(
("agent_description", "language", "context", "knowledge", "history_summary"),
[
(
"mathematician",
"English",
"How old is Susie?",
"Susie is a girl born in 2011/11/14. Today is 2023/12/3",
"balabala... (useless words)",
),
(
"mathematician",
"Chinese",
"Does Susie have an apple?",
"Susie is a girl born in 2011/11/14. Today is 2023/12/3",
"Susie had an apple, and she ate it right now",
),
],
)
@pytest.mark.usefixtures("llm_mock")
async def test_prompt(agent_description, language, context, knowledge, history_summary):
# Prerequisites
CONFIG.agent_description = agent_description
CONFIG.language = language
action = TalkAction(context=context, knowledge=knowledge, history_summary=history_summary)
assert "{" not in action.prompt
assert "{" not in action.prompt_gpt4
rsp = await action.run()
assert rsp
assert isinstance(rsp, Message)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -1,189 +0,0 @@
# -*- coding: utf-8 -*-
# @Date : 2023/7/22 02:40
# @Author : stellahong (stellahong@deepwisdom.ai)
#
from tests.metagpt.roles.ui_role import UIDesign
llm_resp = """
# UI Design Description
```The user interface for the snake game will be designed in a way that is simple, clean, and intuitive. The main elements of the game such as the game grid, snake, food, score, and game over message will be clearly defined and easy to understand. The game grid will be centered on the screen with the score displayed at the top. The game controls will be intuitive and easy to use. The design will be modern and minimalist with a pleasing color scheme.```
## Selected Elements
Game Grid: The game grid will be a rectangular area in the center of the screen where the game will take place. It will be defined by a border and will have a darker background color.
Snake: The snake will be represented by a series of connected blocks that move across the grid. The color of the snake will be different from the background color to make it stand out.
Food: The food will be represented by small objects that are a different color from the snake and the background. The food will be randomly placed on the grid.
Score: The score will be displayed at the top of the screen. The score will increase each time the snake eats a piece of food.
Game Over: When the game is over, a message will be displayed in the center of the screen. The player will be given the option to restart the game.
## HTML Layout
```html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Snake Game</title>
<link rel="stylesheet" href="styles.css">
</head>
<body>
<div class="score">Score: 0</div>
<div class="game-grid">
<!-- Snake and food will be dynamically generated here using JavaScript -->
</div>
<div class="game-over">Game Over</div>
</body>
</html>
```
## CSS Styles (styles.css)
```css
body {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background-color: #f0f0f0;
}
.score {
font-size: 2em;
margin-bottom: 1em;
}
.game-grid {
width: 400px;
height: 400px;
display: grid;
grid-template-columns: repeat(20, 1fr);
grid-template-rows: repeat(20, 1fr);
gap: 1px;
background-color: #222;
border: 1px solid #555;
}
.snake-segment {
background-color: #00cc66;
}
.food {
background-color: #cc3300;
}
.control-panel {
display: flex;
justify-content: space-around;
width: 400px;
margin-top: 1em;
}
.control-button {
padding: 1em;
font-size: 1em;
border: none;
background-color: #555;
color: #fff;
cursor: pointer;
}
.game-over {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
font-size: 3em;
"""
def test_ui_design_parse_css():
ui_design_work = UIDesign(name="UI design action")
css = """
body {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background-color: #f0f0f0;
}
.score {
font-size: 2em;
margin-bottom: 1em;
}
.game-grid {
width: 400px;
height: 400px;
display: grid;
grid-template-columns: repeat(20, 1fr);
grid-template-rows: repeat(20, 1fr);
gap: 1px;
background-color: #222;
border: 1px solid #555;
}
.snake-segment {
background-color: #00cc66;
}
.food {
background-color: #cc3300;
}
.control-panel {
display: flex;
justify-content: space-around;
width: 400px;
margin-top: 1em;
}
.control-button {
padding: 1em;
font-size: 1em;
border: none;
background-color: #555;
color: #fff;
cursor: pointer;
}
.game-over {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
font-size: 3em;
"""
assert ui_design_work.parse_css_code(context=llm_resp) == css
def test_ui_design_parse_html():
ui_design_work = UIDesign(name="UI design action")
html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Snake Game</title>
<link rel="stylesheet" href="styles.css">
</head>
<body>
<div class="score">Score: 0</div>
<div class="game-grid">
<!-- Snake and food will be dynamically generated here using JavaScript -->
</div>
<div class="game-over">Game Over</div>
</body>
</html>
"""
    assert ui_design_work.parse_html_code(context=llm_resp) == html
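Both deleted tests boil down to extracting one fenced block from the LLM response; note that the mock response above is cut off before the CSS block's closing fence, so the extractor has to tolerate a missing terminator. A hypothetical sketch of that kind of extractor (not the deleted implementation):

```python
import re

def parse_code_block(text: str, lang: str) -> str:
    # Capture from ```lang up to the closing fence, or to end-of-text when
    # the response was truncated before the closing ```.
    match = re.search(rf"```{lang}\n(.*?)(?:```|$)", text, re.DOTALL)
    if not match:
        raise ValueError(f"no ```{lang} block found")
    return match.group(1)

# e.g. parse_code_block(llm_resp, "css") returns everything after ```css,
# ending at "font-size: 3em" because the mock response stops there.
```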

View file

@ -6,25 +6,38 @@
@File : test_write_code.py
@Modified By: mashenquan, 2023-12-6. According to RFC 135
"""
from pathlib import Path
import pytest
from metagpt.actions.write_code import WriteCode
from metagpt.llm import LLM
from metagpt.config import CONFIG
from metagpt.const import (
CODE_SUMMARIES_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAILLM as LLM
from metagpt.schema import CodingContext, Document
from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
from metagpt.utils.common import aread
from metagpt.utils.file_repository import FileRepository
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code():
context = CodingContext(
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
)
doc = Document(content=context.json())
doc = Document(content=context.model_dump_json())
write_code = WriteCode(context=doc)
code = await write_code.run()
logger.info(code.json())
logger.info(code.model_dump_json())
    # We cannot exactly predict the generated code, but we can check for certain keywords
assert "def add" in code.code_doc.content
@ -32,8 +45,54 @@ async def test_write_code():
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_directly():
prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0]
llm = LLM()
rsp = await llm.aask(prompt)
logger.info(rsp)
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_deps():
# Prerequisites
CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1"
demo_path = Path(__file__).parent / "../../data/demo_project"
await FileRepository.save_file(
filename="test_game.py.json",
content=await aread(str(demo_path / "test_game.py.json")),
relative_path=TEST_OUTPUTS_FILE_REPO,
)
await FileRepository.save_file(
filename="20231221155954.json",
content=await aread(str(demo_path / "code_summaries.json")),
relative_path=CODE_SUMMARIES_FILE_REPO,
)
await FileRepository.save_file(
filename="20231221155954.json",
content=await aread(str(demo_path / "system_design.json")),
relative_path=SYSTEM_DESIGN_FILE_REPO,
)
await FileRepository.save_file(
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO
)
await FileRepository.save_file(
filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONFIG.src_workspace
)
context = CodingContext(
filename="game.py",
design_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO),
task_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO),
code_doc=Document(filename="game.py", content="", root_path="snake1"),
)
coding_doc = Document(root_path="snake1", filename="game.py", content=context.json())
action = WriteCode(context=coding_doc)
rsp = await action.run()
assert rsp
assert rsp.code_doc.content
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -8,11 +8,11 @@
import pytest
from metagpt.actions.write_code_review import WriteCodeReview
from metagpt.document import Document
from metagpt.schema import CodingContext
from metagpt.schema import CodingContext, Document
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_review(capfd):
code = """
def add(a, b):

View file

@ -27,6 +27,18 @@ class Person:
],
ids=["google", "numpy", "sphinx"],
)
@pytest.mark.usefixtures("llm_mock")
async def test_write_docstring(style: str, part: str):
ret = await WriteDocstring().run(code, style=style)
assert part in ret
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write():
code = await WriteDocstring.write_docstring(__file__)
assert code
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -18,6 +18,7 @@ from metagpt.utils.file_repository import FileRepository
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_prd():
product_manager = ProductManager()
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"

View file

@ -11,6 +11,7 @@ from metagpt.actions.write_prd_review import WritePRDReview
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_prd_review():
prd = """
Introduction: This is a new feature for our product.
@ -23,10 +24,14 @@ async def test_write_prd_review():
Timeline: The feature should be ready for testing in 1.5 months.
"""
write_prd_review = WritePRDReview("write_prd_review")
write_prd_review = WritePRDReview(name="write_prd_review")
prd_review = await write_prd_review.run(prd)
# We cannot exactly predict the generated PRD review, but we can check if it is a string and if it is not empty
assert isinstance(prd_review, str)
assert len(prd_review) > 0
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -46,6 +46,7 @@ CONTEXT = """
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_review():
write_review = WriteReview()
review = await write_review.run(CONTEXT)

View file

@ -0,0 +1,27 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/28 17:25
@Author : mashenquan
@File : test_write_teaching_plan.py
"""
import pytest
from metagpt.actions.write_teaching_plan import WriteTeachingPlanPart
@pytest.mark.asyncio
@pytest.mark.parametrize(
("topic", "context"),
[("Title", "Lesson 1: Learn to draw an apple."), ("Teaching Content", "Lesson 1: Learn to draw an apple.")],
)
@pytest.mark.usefixtures("llm_mock")
async def test_write_teaching_plan_part(topic, context):
action = WriteTeachingPlanPart(topic=topic, context=context)
rsp = await action.run()
assert rsp
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -13,6 +13,7 @@ from metagpt.schema import Document, TestingContext
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_test():
code = """
import random
@ -29,7 +30,7 @@ async def test_write_test():
write_test = WriteTest(context=context)
context = await write_test.run()
logger.info(context.json())
logger.info(context.model_dump_json())
# We cannot exactly predict the generated test cases, but we can check if it is a string and if it is not empty
assert isinstance(context.test_doc.content, str)
@ -39,6 +40,7 @@ async def test_write_test():
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_invalid_code(mocker):
# Mock the _aask method to return an invalid code string
mocker.patch.object(WriteTest, "_aask", return_value="Invalid Code String")
@ -51,3 +53,7 @@ async def test_write_code_invalid_code(mocker):
# Assert that the returned code is the same as the invalid code string
assert code == "Invalid Code String"
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -14,6 +14,7 @@ from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
@pytest.mark.asyncio
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
@pytest.mark.usefixtures("llm_mock")
async def test_write_directory(language: str, topic: str):
ret = await WriteDirectory(language=language).run(topic=topic)
assert isinstance(ret, dict)
@ -29,6 +30,7 @@ async def test_write_directory(language: str, topic: str):
("language", "topic", "directory"),
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
)
@pytest.mark.usefixtures("llm_mock")
async def test_write_content(language: str, topic: str, directory: Dict):
ret = await WriteContent(language=language, directory=directory).run(topic=topic)
assert isinstance(ret, str)

View file

@ -5,73 +5,36 @@
@Author : alexanderwu
@File : test_faiss_store.py
"""
import functools
import pytest
from metagpt.const import DATA_PATH
from metagpt.const import EXAMPLE_PATH
from metagpt.document_store import FaissStore
from metagpt.roles import CustomerService, Sales
DESC = """## 原则(所有事情都不可绕过原则)
1. 你是一位平台的人工客服话语精炼一次只说一句话会参考规则与FAQ进行回复在与顾客交谈中绝不允许暴露规则与相关字样
2. 在遇到问题时先尝试仅安抚顾客情绪如果顾客情绪十分不好再考虑赔偿如果赔偿的过多你会被开除
3. 绝不要向顾客做虚假承诺不要提及其他人的信息
## 技能(在回答尾部,加入`skill(args)`就可以使用技能)
1. 查询订单问顾客手机号是获得订单的唯一方式获得手机号后使用`find_order(手机号)`来获得订单
2. 退款输出关键词 `refund(手机号)`系统会自动退款
3. 开箱需要手机号确认顾客在柜前如果需要开箱输出指令 `open_box(手机号)`系统会自动开箱
### 使用技能例子
user: 你好收不到取餐码
小爽人工: 您好请提供一下手机号
user: 14750187158
小爽人工: 好的为您查询一下订单您已经在柜前了吗`find_order(14750187158)`
user: 是的
小爽人工: 您看下开了没有`open_box(14750187158)`
user: 开了谢谢
小爽人工: 好的还有什么可以帮到您吗
user: 没有了
小爽人工: 祝您生活愉快
"""
from metagpt.logs import logger
from metagpt.roles import Sales
@pytest.mark.asyncio
async def test_faiss_store_search():
store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
store.add(["油皮洗面奶"])
role = Sales(store=store)
queries = ["油皮洗面奶", "介绍下欧莱雅的"]
for query in queries:
rsp = await role.run(query)
assert rsp
def customer_service():
store = FaissStore(DATA_PATH / "st/faq.xlsx", content_col="Question", meta_col="Answer")
store.search = functools.partial(store.search, expand_cols=True)
role = CustomerService(profile="小爽人工", desc=DESC, store=store)
return role
async def test_search_json():
store = FaissStore(EXAMPLE_PATH / "example.json")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
logger.info(result)
@pytest.mark.asyncio
async def test_faiss_store_customer_service():
allq = [
# ["我的餐怎么两小时都没到", "退货吧"],
[
"你好收不到取餐码,麻烦帮我开箱",
"14750187158",
]
]
role = customer_service()
for queries in allq:
for query in queries:
rsp = await role.run(query)
assert rsp
async def test_search_xlsx():
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
logger.info(result)
def test_faiss_store_no_file():
with pytest.raises(FileNotFoundError):
FaissStore(DATA_PATH / "wtf.json")
@pytest.mark.asyncio
async def test_write():
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
_faiss_store = store.write()
assert _faiss_store.docstore
assert _faiss_store.index

View file

@ -7,12 +7,9 @@
"""
import random
import pytest
from metagpt.document_store.lancedb_store import LanceStore
def test_lance_store():
# This simply establishes the connection to the database, so we can drop the table if it exists
store = LanceStore("test")

View file

@ -1,36 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/6/11 21:08
@Author : alexanderwu
@File : test_milvus_store.py
"""
import random
import numpy as np
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
from metagpt.logs import logger
book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float}
book_data = [
[i for i in range(10)],
[f"book-{i}" for i in range(10)],
[f"book-desc-{i}" for i in range(10000, 10010)],
[[random.random() for _ in range(2)] for _ in range(10)],
[random.random() for _ in range(10)],
]
def test_milvus_store():
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
milvus_store = MilvusStore(milvus_connection)
milvus_store.drop("Book")
milvus_store.create_collection("Book", book_columns)
milvus_store.add(book_data)
milvus_store.build_index("emb")
milvus_store.load_collection()
results = milvus_store.search([[1.0, 1.0]], field="emb")
logger.info(results)
assert results

View file

@ -29,7 +29,7 @@ points = [
]
def test_milvus_store():
def test_qdrant_store():
qdrant_connection = QdrantConnection(memory=True)
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
qdrant_store = QdrantStore(qdrant_connection)
@ -43,13 +43,13 @@ def test_milvus_store():
results = qdrant_store.search("Book", query=[1.0, 1.0])
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[1]["score"] == 7
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
assert results[1]["score"] == 7
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
results = qdrant_store.search(

View file

@ -0,0 +1,27 @@
import asyncio
from pydantic import BaseModel
from metagpt.learn.google_search import google_search
async def mock_google_search():
class Input(BaseModel):
input: str
inputs = [{"input": "ai agent"}]
for i in inputs:
seed = Input(**i)
result = await google_search(seed.input)
assert result != ""
def test_suite():
loop = asyncio.get_event_loop()
task = loop.create_task(mock_google_search())
loop.run_until_complete(task)
if __name__ == "__main__":
test_suite()

View file

@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/19
@Author : mashenquan
@File : test_skill_loader.py
@Desc : Unit tests.
"""
import pytest
from metagpt.config import CONFIG
from metagpt.learn.skill_loader import SkillsDeclaration
@pytest.mark.asyncio
async def test_suite():
CONFIG.agent_skills = [
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
]
loader = await SkillsDeclaration.load()
skills = loader.get_skill_list()
assert skills
assert len(skills) >= 3
for desc, name in skills.items():
assert desc
assert name
entity = loader.entities.get("Assistant")
assert entity
assert entity.skills
for sk in entity.skills:
assert sk
assert sk.arguments
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : test_text_to_embedding.py
@Desc : Unit tests.
"""
import pytest
from metagpt.config import CONFIG
from metagpt.learn.text_to_embedding import text_to_embedding
@pytest.mark.asyncio
async def test_text_to_embedding():
# Prerequisites
assert CONFIG.OPENAI_API_KEY
v = await text_to_embedding(text="Panda emoji")
assert len(v.data) > 0
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,39 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : test_text_to_image.py
@Desc : Unit tests.
"""
import pytest
from metagpt.config import CONFIG
from metagpt.learn.text_to_image import text_to_image
@pytest.mark.asyncio
async def test_metagpt_llm():
# Prerequisites
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
assert CONFIG.OPENAI_API_KEY
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
CONFIG.set_context(new_options)
try:
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
finally:
CONFIG.set_context(old_options)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : test_text_to_speech.py
@Desc : Unit tests.
"""
import pytest
from metagpt.config import CONFIG
from metagpt.learn.text_to_speech import text_to_speech
@pytest.mark.asyncio
async def test_text_to_speech():
# Prerequisites
assert CONFIG.IFLYTEK_APP_ID
assert CONFIG.IFLYTEK_API_KEY
assert CONFIG.IFLYTEK_API_SECRET
assert CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
assert CONFIG.AZURE_TTS_REGION
# test azure
data = await text_to_speech("panda emoji")
assert "base64" in data or "http" in data
# test iflytek
## Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = ""
CONFIG.set_context(new_options)
try:
data = await text_to_speech("panda emoji")
assert "base64" in data or "http" in data
finally:
CONFIG.set_context(old_options)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -14,9 +14,9 @@ def test_skill_manager():
manager = SkillManager()
logger.info(manager._store)
write_prd = WritePRD("WritePRD")
write_prd = WritePRD(name="WritePRD")
write_prd.desc = "基于老板或其他人的需求进行PRD的撰写包括用户故事、需求分解等"
write_test = WriteTest("WriteTest")
write_test = WriteTest(name="WriteTest")
write_test.desc = "进行测试用例的撰写"
manager.add_skill(write_prd)
manager.add_skill(write_test)
@ -24,7 +24,7 @@ def test_skill_manager():
skill = manager.get_skill("WriteTest")
logger.info(skill)
rsp = manager.retrieve_skill("PRD")
rsp = manager.retrieve_skill("WritePRD")
logger.info(rsp)
assert rsp[0] == "WritePRD"


@ -0,0 +1,71 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/27
@Author : mashenquan
@File : test_brain_memory.py
"""
import pytest
from metagpt.config import LLMProviderEnum
from metagpt.llm import LLM
from metagpt.memory.brain_memory import BrainMemory
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_memory():
memory = BrainMemory()
memory.add_talk(Message(content="talk"))
assert memory.history[0].role == "user"
memory.add_answer(Message(content="answer"))
assert memory.history[1].role == "assistant"
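# Persist the conversation to Redis under a composite key, then reload it below to verify the round trip.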
redis_key = BrainMemory.to_redis_key("none", "user_id", "chat_id")
await memory.dumps(redis_key=redis_key)
assert memory.exists("talk")
assert 1 == memory.to_int("1", 0)
memory.last_talk = "AAA"
assert memory.pop_last_talk() == "AAA"
assert memory.last_talk is None
assert memory.is_history_available
assert memory.history_text
memory = await BrainMemory.loads(redis_key=redis_key)
assert memory
@pytest.mark.parametrize(
("input", "tag", "val"),
[("[TALK]:Hello", "TALK", "Hello"), ("Hello", None, "Hello"), ("[TALK]Hello", None, "[TALK]Hello")],
)
def test_extract_info(input, tag, val):
t, v = BrainMemory.extract_info(input)
assert tag == t
assert val == v
@pytest.mark.asyncio
@pytest.mark.parametrize("llm", [LLM(provider=LLMProviderEnum.OPENAI), LLM(provider=LLMProviderEnum.METAGPT)])
async def test_memory_llm(llm):
memory = BrainMemory()
for i in range(500):
memory.add_talk(Message(content="Lily is a girl.\n"))
res = await memory.is_related("apple", "moon", llm)
assert not res
res = await memory.rewrite(sentence="apple Lily eating", context="", llm=llm)
assert "Lily" in res
res = await memory.summarize(llm=llm)
assert res
res = await memory.get_title(llm=llm)
assert res
assert "Lily" in res
assert memory.history or memory.historical_summary
if __name__ == "__main__":
pytest.main([__file__, "-s"])


@ -2,22 +2,31 @@
# -*- coding: utf-8 -*-
"""
@Desc : unittest of `metagpt/memory/longterm_memory.py`
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
import os
import pytest
from metagpt.actions import UserRequirement
from metagpt.config import CONFIG
from metagpt.memory import LongTermMemory
from metagpt.memory.longterm_memory import LongTermMemory
from metagpt.roles.role import RoleContext
from metagpt.schema import Message
def test_ltm_search():
assert hasattr(CONFIG, "long_term_memory") is True
openai_api_key = CONFIG.openai_api_key
assert len(openai_api_key) > 20
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
assert len(CONFIG.openai_api_key) > 20
role_id = "UTUserLtm(Product Manager)"
rc = RoleContext(watch=[UserRequirement])
from metagpt.environment import Environment
Environment  # noqa - keep the import: RoleContext.model_rebuild() resolves the "Environment" forward reference from this namespace
RoleContext.model_rebuild()
rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"})
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)
@ -28,6 +37,7 @@ def test_ltm_search():
ltm.add(message)
sim_idea = "Write a game of cli snake"
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
news = ltm.find_news([sim_message])
assert len(news) == 0
@ -55,3 +65,7 @@ def test_ltm_search():
assert len(news) == 1
ltm_new.clear()
if __name__ == "__main__":
pytest.main([__file__, "-s"])


@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of Memory
from metagpt.actions import UserRequirement
from metagpt.memory.memory import Memory
from metagpt.schema import Message
def test_memory():
memory = Memory()
message1 = Message(content="test message1", role="user1")
message2 = Message(content="test message2", role="user2")
message3 = Message(content="test message3", role="user1")
memory.add(message1)
assert memory.count() == 1
memory.delete_newest()
assert memory.count() == 0
memory.add_batch([message1, message2])
assert memory.count() == 2
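# Both messages fall back to the same default cause_by, so they share one index bucket.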
assert len(memory.index.get(message1.cause_by)) == 2
messages = memory.get_by_role("user1")
assert messages[0].content == message1.content
messages = memory.get_by_content("test message")
assert len(messages) == 2
messages = memory.get_by_action(UserRequirement)
assert len(messages) == 2
messages = memory.get_by_actions([UserRequirement])
assert len(messages) == 2
messages = memory.try_remember("test message")
assert len(messages) == 2
messages = memory.get(k=1)
assert len(messages) == 1
messages = memory.get(k=5)
assert len(messages) == 2
messages = memory.find_news([message3])
assert len(messages) == 1
memory.delete(message1)
assert memory.count() == 1
messages = memory.get_by_role("user2")
assert messages[0].content == message2.content
memory.clear()
assert memory.count() == 0
assert len(memory.index) == 0


@ -4,20 +4,28 @@
@Desc : the unittests of metagpt/memory/memory_storage.py
"""
import os
import shutil
from pathlib import Path
from typing import List
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.action_node import ActionNode
from metagpt.config import CONFIG
from metagpt.const import DATA_PATH
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
def test_idea_message():
idea = "Write a cli snake game"
role_id = "UTUser1(Product Manager)"
message = Message(role="User", content=idea, cause_by=UserRequirement)
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
assert len(messages) == 0
@ -27,12 +35,12 @@ def test_idea_message():
sim_idea = "Write a game of cli snake"
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
new_messages = memory_storage.search(sim_message)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_idea = "Write a 2048 web game"
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
new_messages = memory_storage.search(new_message)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
memory_storage.clean()
@ -50,6 +58,8 @@ def test_actionout_message():
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
) # WritePRD as test action
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
assert len(messages) == 0
@ -59,12 +69,12 @@ def test_actionout_message():
sim_conent = "The request is command-line interface (CLI) snake game"
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(sim_message)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(new_message)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
memory_storage.clean()


@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :


@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin
raw_output = """
[CONTENT]
{
"Original Requirements": "xxx"
}
[/CONTENT]
"""
raw_schema = {
"title": "prd",
"type": "object",
"properties": {
"Original Requirements": {"title": "Original Requirements", "type": "string"},
},
"required": [
"Original Requirements",
],
}
def test_llm_post_process_plugin():
post_process_plugin = BasePostProcessPlugin()
output = post_process_plugin.run(output=raw_output, schema=raw_schema)
assert "Original Requirements" in output


@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import (
raw_output,
raw_schema,
)
def test_llm_output_postprocess():
output = llm_output_postprocess(output=raw_output, schema=raw_schema)
assert "Original Requirements" in output


@ -0,0 +1,34 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of Claude2
import pytest
from anthropic.resources.completions import Completion
from metagpt.config import CONFIG
from metagpt.provider.anthropic_api import Claude2
CONFIG.anthropic_api_key = "xxx"
prompt = "who are you"
resp = "I'am Claude2"
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
def test_claude2_ask(mocker):
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
assert resp == Claude2().ask(prompt)
@pytest.mark.asyncio
async def test_claude2_aask(mocker):
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
assert resp == await Claude2().aask(prompt)


@ -0,0 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from metagpt.config import CONFIG
from metagpt.provider.azure_openai_api import AzureOpenAILLM
CONFIG.OPENAI_API_VERSION = "xx"
CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value
def test_azure_openai_api():
_ = AzureOpenAILLM()


@ -3,13 +3,104 @@
"""
@Time : 2023/5/7 17:40
@Author : alexanderwu
@File : test_base_gpt_api.py
@File : test_base_llm.py
"""
import pytest
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
default_chat_resp = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I'am GPT",
},
"finish_reason": "stop",
}
]
}
prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
def test_message():
message = Message(role="user", content="wtf")
class MockBaseLLM(BaseLLM):
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp
async def acompletion(self, messages: list[dict], timeout=3):
return default_chat_resp
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
return resp_content
async def close(self):
return default_chat_resp
def test_base_llm():
message = Message(role="user", content="hello")
assert "role" in message.to_dict()
assert "user" in str(message)
base_llm = MockBaseLLM()
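# Canned OpenAI-style tool-call payload used to exercise the get_choice_* helpers below.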
openai_funccall_resp = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "test",
"tool_calls": [
{
"id": "call_Y5r6Ddr2Qc2ZrqgfwzPX5l72",
"type": "function",
"function": {
"name": "execute",
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
},
}
],
},
"finish_reason": "stop",
}
]
}
func: dict = base_llm.get_choice_function(openai_funccall_resp)
assert func == {
"name": "execute",
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
}
func_args: dict = base_llm.get_choice_function_arguments(openai_funccall_resp)
assert func_args == {"language": "python", "code": "print('Hello, World!')"}
choice_text = base_llm.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
# resp = base_llm.ask(prompt_msg)
# assert resp == resp_content
# resp = base_llm.ask_batch([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_code([prompt_msg])
# assert resp == resp_content
@pytest.mark.asyncio
async def test_async_base_llm():
base_llm = MockBaseLLM()
resp = await base_llm.aask(prompt_msg)
assert resp == resp_content
resp = await base_llm.aask_batch([prompt_msg])
assert resp == resp_content
resp = await base_llm.aask_code([prompt_msg])
assert resp == resp_content


@ -0,0 +1,118 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of fireworks api
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.config import CONFIG
from metagpt.provider.fireworks_api import (
MODEL_GRADE_TOKEN_COSTS,
FireworksCostManager,
FireworksLLM,
)
from metagpt.utils.cost_manager import Costs
CONFIG.fireworks_api_key = "xxx"
CONFIG.max_budget = 10
CONFIG.calc_usage = True
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="accounts/fireworks/models/llama-v2-13b-chat",
object="chat.completion",
created=1703300855,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=dict(default_resp.usage),
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
def test_fireworks_costmanager():
cost_manager = FireworksCostManager()
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test")
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat")
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat")
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat")
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat")
assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat")
assert cost_manager.total_cost == 0.5
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
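# Stand-in for AsyncCompletions.create: a one-chunk async iterator when streaming, otherwise the full response.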
if stream:
class Iterator(object):
async def __aiter__(self):
yield default_resp_chunk
return Iterator()
else:
return default_resp
@pytest.mark.asyncio
async def test_fireworks_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
fireworks_gpt = FireworksLLM()
fireworks_gpt.model = "llama-v2-13b-chat"
fireworks_gpt._update_costs(
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
)
assert fireworks_gpt.get_costs() == Costs(
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
)
resp = await fireworks_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await fireworks_gpt.aask(prompt_msg)
assert resp == resp_content


@ -0,0 +1,136 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import os
from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
import aiohttp
import pytest
import requests
from openai import OpenAIError
from metagpt.provider.general_api_base import (
APIRequestor,
ApiType,
OpenAIResponse,
_aiohttp_proxies_arg,
_build_api_url,
_make_session,
_requests_proxies_arg,
log_debug,
log_info,
log_warn,
logfmt,
parse_stream,
parse_stream_helper,
)
def test_basic():
_ = ApiType.from_str("azure")
_ = ApiType.from_str("azuread")
_ = ApiType.from_str("openai")
with pytest.raises(OpenAIError):
_ = ApiType.from_str("xx")
os.environ.setdefault("LLM_LOG", "debug")
log_debug("debug")
log_warn("warn")
log_info("info")
logfmt({"k1": b"v1", "k2": 1, "k3": "a b"})
_build_api_url(url="http://www.baidu.com/s?wd=", query="baidu")
def test_openai_response():
resp = OpenAIResponse(data=[], headers={"retry-after": 3})
assert resp.request_id is None
assert resp.retry_after == 3
assert resp.operation_location is None
assert resp.organization is None
assert resp.response_ms is None
def test_proxy():
assert _requests_proxies_arg(proxy=None) is None
proxy = "127.0.0.1:80"
assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy}
proxy_dict = {"http": proxy}
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
assert _aiohttp_proxies_arg(proxy_dict) == proxy
proxy_dict = {"https": proxy}
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
assert _aiohttp_proxies_arg(proxy_dict) == proxy
assert _make_session() is not None
assert _aiohttp_proxies_arg(None) is None
assert _aiohttp_proxies_arg("test") == "test"
with pytest.raises(ValueError):
_aiohttp_proxies_arg(-1)
def test_parse_stream():
assert parse_stream_helper(None) is None
assert parse_stream_helper(b"data: [DONE]") is None
assert parse_stream_helper(b"data: test") == "test"
assert parse_stream_helper(b"test") is None
for line in parse_stream([b"data: test"]):
assert line == "test"
api_requestor = APIRequestor(base_url="http://www.baidu.com")
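# The mocks below short-circuit response interpretation so no real HTTP traffic is sent.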
def mock_interpret_response(
self, result: requests.Response, stream: bool
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
return b"baidu", False
async def mock_interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
return b"baidu", True
def test_requestor_headers():
# validate_headers
headers = api_requestor._validate_headers(None)
assert not headers
with pytest.raises(Exception):
api_requestor._validate_headers(-1)
with pytest.raises(Exception):
api_requestor._validate_headers({1: 2})
with pytest.raises(Exception):
api_requestor._validate_headers({"test": 1})
supplied_headers = {"test": "test"}
assert api_requestor._validate_headers(supplied_headers) == supplied_headers
api_requestor.organization = "test"
api_requestor.api_version = "test123"
api_requestor.api_type = ApiType.OPEN_AI
request_id = "test123"
headers = api_requestor.request_headers(method="post", extra={}, request_id=request_id)
assert headers["LLM-Organization"] == api_requestor.organization
assert headers["LLM-Version"] == api_requestor.api_version
assert headers["X-Request-Id"] == request_id
def test_api_requestor(mocker):
mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response)
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
resp, _, _ = api_requestor.request(method="post", url="/s?wd=baidu")
@pytest.mark.asyncio
async def test_async_api_requestor(mocker):
mocker.patch(
"metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response
)
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu")


@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of APIRequestor
import pytest
from metagpt.provider.general_api_requestor import (
GeneralAPIRequestor,
parse_stream,
parse_stream_helper,
)
api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com")
def test_parse_stream():
assert parse_stream_helper(None) is None
assert parse_stream_helper(b"data: [DONE]") is None
assert parse_stream_helper(b"data: test") == b"test"
assert parse_stream_helper(b"test") is None
for line in parse_stream([b"data: test"]):
assert line == b"test"
def test_api_requestor():
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
assert b"baidu" in resp
@pytest.mark.asyncio
async def test_async_api_requestor():
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
assert b"baidu" in resp


@ -6,10 +6,13 @@ from abc import ABC
from dataclasses import dataclass
import pytest
from google.ai import generativelanguage as glm
from google.generativeai.types import content_types
from metagpt.provider.google_gemini_api import GeminiGPTAPI
from metagpt.config import CONFIG
from metagpt.provider.google_gemini_api import GeminiLLM
messages = [{"role": "user", "parts": "who are you"}]
CONFIG.gemini_api_key = "xx"
@dataclass
@ -17,25 +20,70 @@ class MockGeminiResponse(ABC):
text: str
default_resp = MockGeminiResponse(text="I'm gemini from google")
prompt_msg = "who are you"
messages = [{"role": "user", "parts": prompt_msg}]
resp_content = "I'm gemini from google"
default_resp = MockGeminiResponse(text=resp_content)
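# The mocks below stub token counting and content generation so no Google API call is made.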
def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse:
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
return glm.CountTokensResponse(total_tokens=20)
async def mock_gemini_count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
return glm.CountTokensResponse(total_tokens=20)
def mock_gemini_generate_content(self, **kwargs) -> MockGeminiResponse:
return default_resp
def test_gemini_completion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask)
resp = GeminiGPTAPI().completion(messages)
assert resp.text == default_resp.text
async def mock_gemini_generate_content_async(self, stream: bool = False, **kwargs) -> MockGeminiResponse:
if stream:
class Iterator(object):
async def __aiter__(self):
yield default_resp
async def mock_llm_aask(self, messages: list[dict]) -> MockGeminiResponse:
return default_resp
return Iterator()
else:
return default_resp
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask)
resp = await GeminiGPTAPI().acompletion(messages)
mocker.patch("metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens", mock_gemini_count_tokens)
mocker.patch(
"metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens_async", mock_gemini_count_tokens_async
)
mocker.patch("google.generativeai.generative_models.GenerativeModel.generate_content", mock_gemini_generate_content)
mocker.patch(
"google.generativeai.generative_models.GenerativeModel.generate_content_async",
mock_gemini_generate_content_async,
)
gemini_gpt = GeminiLLM()
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
usage = gemini_gpt.get_usage(messages, resp_content)
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
resp = gemini_gpt.completion(messages)
assert resp == default_resp
resp = await gemini_gpt.acompletion(messages)
assert resp.text == default_resp.text
resp = await gemini_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await gemini_gpt.aask(prompt_msg)
assert resp == resp_content


@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of HumanProvider
import pytest
from metagpt.provider.human_provider import HumanProvider
resp_content = "test"
resp_exit = "exit"
@pytest.mark.asyncio
async def test_async_human_provider(mocker):
mocker.patch("builtins.input", lambda _: resp_content)
human_provider = HumanProvider()
resp = human_provider.ask(resp_content)
assert resp == resp_content
resp = await human_provider.aask(None)
assert resp_content == resp
mocker.patch("builtins.input", lambda _: resp_exit)
with pytest.raises(SystemExit):
human_provider.ask(resp_exit)
resp = await human_provider.acompletion([])
assert not resp
resp = await human_provider.acompletion_text([])
assert resp == ""


@ -0,0 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/28
@Author : mashenquan
@File : test_metagpt_api.py
"""
from metagpt.config import LLMProviderEnum
from metagpt.llm import LLM
def test_llm():
llm = LLM(provider=LLMProviderEnum.METAGPT)
assert llm


@ -0,0 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/30
@Author : mashenquan
@File : test_metagpt_llm_api.py
"""
from metagpt.provider.metagpt_api import MetaGPTLLM
def test_metagpt():
llm = MetaGPTLLM()
assert llm
if __name__ == "__main__":
test_metagpt()


@ -2,32 +2,61 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of ollama api
import json
from typing import Any, Tuple
import pytest
from metagpt.provider.ollama_api import OllamaGPTAPI
from metagpt.config import CONFIG
from metagpt.provider.ollama_api import OllamaLLM
messages = [{"role": "user", "content": "who are you"}]
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm ollama"
default_resp = {"message": {"role": "assistant", "content": resp_content}}
CONFIG.ollama_api_base = "http://xxx"
CONFIG.max_budget = 10
default_resp = {"message": {"role": "assisant", "content": "I'm ollama"}}
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
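# Fake the low-level requestor: a stream of NDJSON events when stream=True, otherwise a single JSON body.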
if stream:
class Iterator(object):
events = [
b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}',
b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}',
]
def mock_llm_ask(self, messages: list[dict]) -> dict:
return default_resp
async def __aiter__(self):
for event in self.events:
yield event
def test_ollama_completion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask)
resp = OllamaGPTAPI().completion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
async def mock_llm_aask(self, messages: list[dict]) -> dict:
return default_resp
return Iterator(), None, None
else:
raw_default_resp = default_resp.copy()
raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20})
return json.dumps(raw_default_resp).encode(), None, None
@pytest.mark.asyncio
async def test_ollama_acompletion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask)
resp = await OllamaGPTAPI().acompletion(messages)
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
ollama_gpt = OllamaLLM()
resp = await ollama_gpt.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = await ollama_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await ollama_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await ollama_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await ollama_gpt.aask(prompt_msg)
assert resp == resp_content


@ -0,0 +1,95 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.config import CONFIG
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.utils.cost_manager import Costs
CONFIG.max_budget = 10
CONFIG.calc_usage = True
resp_content = "I'm llama2"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="llama-v2-13b-chat",
object="chat.completion",
created=1703302755,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
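# Same pattern as the fireworks mock: one chunk when streaming, a full ChatCompletion otherwise.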
if stream:
class Iterator(object):
async def __aiter__(self):
yield default_resp_chunk
return Iterator()
else:
return default_resp
@pytest.mark.asyncio
async def test_openllm_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
openllm_gpt = OpenLLM()
openllm_gpt.model = "llama-v2-13b-chat"
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_gpt.get_costs() == Costs(
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
)
resp = await openllm_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
resp = await openllm_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await openllm_gpt.aask(prompt_msg)
assert resp == resp_content


@ -2,13 +2,16 @@ from unittest.mock import Mock
import pytest
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.config import CONFIG
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import UserMessage
CONFIG.openai_proxy = None
@pytest.mark.asyncio
async def test_aask_code():
llm = OpenAIGPTAPI()
llm = OpenAILLM()
msg = [{"role": "user", "content": "Write a python hello world code."}]
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -18,7 +21,7 @@ async def test_aask_code():
@pytest.mark.asyncio
async def test_aask_code_str():
llm = OpenAIGPTAPI()
llm = OpenAILLM()
msg = "Write a python hello world code."
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -28,7 +31,7 @@ async def test_aask_code_str():
@pytest.mark.asyncio
async def test_aask_code_Message():
llm = OpenAIGPTAPI()
llm = OpenAILLM()
msg = UserMessage("Write a python hello world code.")
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -36,63 +39,26 @@ async def test_aask_code_Message():
assert len(rsp["code"]) > 0
def test_ask_code():
llm = OpenAIGPTAPI()
msg = [{"role": "user", "content": "Write a python hello world code."}]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_str():
llm = OpenAIGPTAPI()
msg = "Write a python hello world code."
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_Message():
llm = OpenAIGPTAPI()
msg = UserMessage("Write a python hello world code.")
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_list_Message():
llm = OpenAIGPTAPI()
msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_list_str():
llm = OpenAIGPTAPI()
msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
print(rsp)
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
class TestOpenAI:
@pytest.fixture
def config(self):
return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other")
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy=None,
openai_api_type="other",
)
@pytest.fixture
def config_azure(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_api_version="test_version",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy=None,
openai_api_type="azure",
)
@ -101,7 +67,9 @@ class TestOpenAI:
def config_proxy(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy="http://proxy.com",
openai_api_type="other",
)
@ -110,40 +78,36 @@ class TestOpenAI:
def config_azure_proxy(self):
return Mock(
openai_api_key="test_key",
OPENAI_API_KEY="test_key",
openai_api_version="test_version",
openai_base_url="test_url",
OPENAI_BASE_URL="test_url",
openai_proxy="http://proxy.com",
openai_api_type="azure",
)
def test_make_client_kwargs_without_proxy(self, config):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config_azure
kwargs, async_kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_with_proxy(self, config_proxy):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config_azure_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs


@ -1,11 +1,61 @@
from metagpt.logs import logger
from metagpt.provider.spark_api import SparkAPI
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of spark api
import pytest
from metagpt.config import CONFIG
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
CONFIG.spark_appid = "xxx"
CONFIG.spark_api_secret = "xxx"
CONFIG.spark_api_key = "xxx"
CONFIG.domain = "xxxxxx"
CONFIG.spark_url = "xxxx"
prompt_msg = "who are you"
resp_content = "I'm Spark"
def test_message():
llm = SparkAPI()
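# Minimal stub of websocket.WebSocketApp: accepts the callbacks but never opens a connection.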
class MockWebSocketApp(object):
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
pass
logger.info(llm.ask('只回答"收到了"这三个字。'))
result = llm.ask("写一篇五百字的日记")
logger.info(result)
assert len(result) > 100
def run_forever(self, sslopt=None):
pass
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(text=prompt_msg)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx"
ret = get_msg_from_web.run()
assert ret == ""
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_content
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_gpt = SparkLLM()
resp = await spark_gpt.acompletion([])
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=True)
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg)
assert resp == resp_content


@ -1,47 +1,89 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of ZhiPuAIGPTAPI
# @Desc : the unittest of ZhiPuAILLM
import pytest
from zhipuai.utils.sse_client import Event
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.config import CONFIG
from metagpt.provider.zhipuai_api import ZhiPuAILLM
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": "I'm chatglm-turbo"}]}}
CONFIG.zhipuai_api_key = "xxx.xxx"
messages = [{"role": "user", "content": "who are you"}]
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm chatglm-turbo"
default_resp = {
"code": 200,
"data": {
"choices": [{"role": "assistant", "content": resp_content}],
"usage": {"prompt_tokens": 20, "completion_tokens": 20},
},
}
def mock_llm_ask(self, messages: list[dict]) -> dict:
def mock_zhipuai_invoke(**kwargs) -> dict:
return default_resp
def test_zhipuai_completion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask)
resp = ZhiPuAIGPTAPI().completion(messages)
assert resp["code"] == 200
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
async def mock_llm_aask(self, messages: list[dict], stream: bool = False) -> dict:
async def mock_zhipuai_ainvoke(**kwargs) -> dict:
return default_resp
async def mock_zhipuai_asse_invoke(**kwargs):
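# Emulate the SSE channel: an "add" event carrying content, then a "finish" event with token usage.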
class MockResponse(object):
async def _aread(self):
class Iterator(object):
events = [
Event(id="xxx", event="add", data=resp_content, retry=0),
Event(
id="xxx",
event="finish",
data="",
meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}',
),
]
async def __aiter__(self):
for event in self.events:
yield event
async for chunk in Iterator():
yield chunk
async def async_events(self):
async for chunk in self._aread():
yield chunk
return MockResponse()
@pytest.mark.asyncio
async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke)
resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False)
zhipu_gpt = ZhiPuAILLM()
assert resp["code"] == 200
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
resp = await zhipu_gpt.acompletion(messages)
assert resp["data"]["choices"][0]["content"] == resp_content
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await zhipu_gpt.aask(prompt_msg)
assert resp == resp_content
def test_zhipuai_proxy(mocker):
import openai
from metagpt.config import CONFIG
CONFIG.openai_proxy = "http://127.0.0.1:8080"
_ = ZhiPuAIGPTAPI()
assert openai.proxy == CONFIG.openai_proxy
def test_zhipuai_proxy():
# CONFIG.openai_proxy = "http://127.0.0.1:8080"
_ = ZhiPuAILLM()
# assert openai.proxy == CONFIG.openai_proxy


@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :


@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
@pytest.mark.asyncio
async def test_async_sse_client():
class Iterator(object):
async def __aiter__(self):
yield b"data: test_value"
async_sse_client = AsyncSSEClient(event_source=Iterator())
async for event in async_sse_client.async_events():
assert event.data, "test_value"
class InvalidIterator(object):
async def __aiter__(self):
yield b"invalid: test_value"
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
async for event in async_sse_client.async_events():
assert not event

View file

@ -0,0 +1,44 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from typing import Any, Tuple
import pytest
import zhipuai
from zhipuai.model_api.api import InvokeType
from zhipuai.utils.http_client import headers as zhipuai_default_headers
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
api_key = "xxx.xxx"
zhipuai.api_key = api_key
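# The "xxx.xxx" value mimics the "<id>.<secret>" key shape the auth header builders appear to expect.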
default_resp = b'{"result": "test response"}'
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
return default_resp, None, None
@pytest.mark.asyncio
async def test_zhipu_model_api(mocker):
header = ZhiPuModelAPI.get_header()
zhipuai_default_headers.update({"Authorization": api_key})
assert header == zhipuai_default_headers
sse_header = ZhiPuModelAPI.get_sse_header()
assert len(sse_header["Authorization"]) == 191
url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"})
assert url_prefix == "https://open.bigmodel.cn/api"
assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke"
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest)
result = await ZhiPuModelAPI.arequest(
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
)
assert result == default_resp
result = await ZhiPuModelAPI.ainvoke()
assert result["result"] == "test response"

View file

@ -3,8 +3,10 @@
"""
@Time : 2023/5/12 13:05
@Author : alexanderwu
@File : mock.py
@File : mock_markdown.py
"""
import json
from metagpt.actions import UserRequirement, WriteDesign, WritePRD, WriteTasks
from metagpt.schema import Message
@ -151,6 +153,32 @@ sequenceDiagram
```
"""
JSON_TASKS = {
"Logic Analysis": """
在这个项目中所有的模块都依赖于SearchEngine这是主入口其他的模块IndexRanking和Summary都通过它交互另外"Index"类又依赖于"KnowledgeBase"因为它需要从知识库中获取数据
- "main.py"包含"Main"是程序的入口点它调用"SearchEngine"进行搜索操作所以在其他任何模块之前"SearchEngine"必须首先被定义
- "search.py"定义了"SearchEngine"它依赖于"Index""Ranking""Summary"因此这些模块需要在"search.py"之前定义
- "index.py"定义了"Index"它从"knowledge_base.py"获取数据来创建索引所以"knowledge_base.py"需要在"index.py"之前定义
- "ranking.py""summary.py"相对独立只需确保在"search.py"之前定义
- "knowledge_base.py"是独立的模块可以优先开发
- "interface.py""user_feedback.py""security.py""testing.py""monitoring.py"看起来像是功能辅助模块可以在主要功能模块开发完成后并行开发
""",
"Task list": [
"smart_search_engine/knowledge_base.py",
"smart_search_engine/index.py",
"smart_search_engine/ranking.py",
"smart_search_engine/summary.py",
"smart_search_engine/search.py",
"smart_search_engine/main.py",
"smart_search_engine/interface.py",
"smart_search_engine/user_feedback.py",
"smart_search_engine/security.py",
"smart_search_engine/testing.py",
"smart_search_engine/monitoring.py",
],
}
TASKS = """## Logic Analysis
@ -256,3 +284,4 @@ class MockMessages:
prd = Message(role="Product Manager", content=PRD, cause_by=WritePRD)
system_design = Message(role="Architect", content=SYSTEM_DESIGN, cause_by=WriteDesign)
tasks = Message(role="Project Manager", content=TASKS, cause_by=WriteTasks)
json_tasks = Message(role="Project Manager", content=json.dumps(JSON_TASKS), cause_by=WriteTasks)


@ -7,17 +7,39 @@
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import uuid
import pytest
from metagpt.actions import WriteDesign, WritePRD
from metagpt.config import CONFIG
from metagpt.const import PRDS_FILE_REPO
from metagpt.logs import logger
from metagpt.roles import Architect
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, awrite
from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_architect():
# Prerequisites
filename = uuid.uuid4().hex + ".json"
await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)
role = Architect()
role.put_message(MockMessages.req)
rsp = await role.run(MockMessages.prd)
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
logger.info(rsp)
assert len(rsp.content) > 0
assert rsp.cause_by == any_to_str(WriteDesign)
# test update
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
assert rsp
assert rsp.cause_by == any_to_str(WriteDesign)
assert len(rsp.content) > 0
if __name__ == "__main__":
pytest.main([__file__, "-s"])


@ -0,0 +1,134 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/25
@Author : mashenquan
@File : test_assistant.py
@Desc : Used by AgentStore.
"""
import pytest
from pydantic import BaseModel
from metagpt.actions.skill_action import SkillAction
from metagpt.actions.talk_action import TalkAction
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.memory.brain_memory import BrainMemory
from metagpt.roles.assistant import Assistant
from metagpt.schema import Message
from metagpt.utils.common import any_to_str
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_run():
CONFIG.language = "Chinese"
class Input(BaseModel):
memory: BrainMemory
language: str
agent_description: str
cause_by: str
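# Two scenarios: small talk that should route to TalkAction, and a drawing request that should route to SkillAction.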
inputs = [
{
"memory": {
"history": [
{
"content": "who is tulin",
"role": "user",
"id": "1",
},
{"content": "The one who eaten a poison apple.", "role": "assistant"},
],
"knowledge": [{"content": "tulin is a scientist."}],
"last_talk": "Do you have a poison apple?",
},
"language": "English",
"agent_description": "chatterbox",
"cause_by": any_to_str(TalkAction),
},
{
"memory": {
"history": [
{
"content": "can you draw me an picture?",
"role": "user",
"id": "1",
},
{"content": "Yes, of course. What do you want me to draw", "role": "assistant"},
],
"knowledge": [{"content": "tulin is a scientist."}],
"last_talk": "Draw me an apple.",
},
"language": "English",
"agent_description": "painter",
"cause_by": any_to_str(SkillAction),
},
]
CONFIG.agent_skills = [
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
]
for i in inputs:
seed = Input(**i)
CONFIG.language = seed.language
CONFIG.agent_description = seed.agent_description
role = Assistant(language="Chinese")
role.memory = seed.memory # Restore historical conversation content.
while True:
has_action = await role.think()
if not has_action:
break
msg: Message = await role.act()
logger.info(msg)
assert msg
assert msg.cause_by == seed.cause_by
assert msg.content
@pytest.mark.parametrize(
"memory",
[
{
"history": [
{
"content": "can you draw me an picture?",
"role": "user",
"id": "1",
},
{"content": "Yes, of course. What do you want me to draw", "role": "assistant"},
],
"knowledge": [{"content": "tulin is a scientist."}],
"last_talk": "Draw me an apple.",
}
],
)
@pytest.mark.asyncio
async def test_memory(memory):
role = Assistant()
role.load_memory(memory)
val = role.get_memory()
assert val
await role.talk("draw apple")
agent_skills = CONFIG.agent_skills
CONFIG.agent_skills = []
try:
await role.think()
finally:
CONFIG.agent_skills = agent_skills
assert isinstance(role.rc.todo, TalkAction)
if __name__ == "__main__":
pytest.main([__file__, "-s"])


@ -7,35 +7,52 @@
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import json
from pathlib import Path
import pytest
from metagpt.actions import WriteCode, WriteTasks
from metagpt.config import CONFIG
from metagpt.const import (
PRDS_FILE_REPO,
REQUIREMENT_FILENAME,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.roles.engineer import Engineer
from metagpt.utils.common import CodeParser
from tests.metagpt.roles.mock import (
STRS_FOR_PARSING,
TASKS,
TASKS_TOMATO_CLOCK,
MockMessages,
)
from metagpt.schema import CodingContext, Message
from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import ChangeType
from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_engineer():
engineer = Engineer()
# Prerequisites
rqno = "20231221155954.json"
await FileRepository.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content)
await FileRepository.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content)
await FileRepository.save_file(
rqno, relative_path=SYSTEM_DESIGN_FILE_REPO, content=MockMessages.system_design.content
)
await FileRepository.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content)
engineer.put_message(MockMessages.req)
engineer.put_message(MockMessages.prd)
engineer.put_message(MockMessages.system_design)
rsp = await engineer.run(MockMessages.tasks)
engineer = Engineer()
rsp = await engineer.run(Message(content="", cause_by=WriteTasks))
logger.info(rsp)
assert "all done." == rsp.content
assert rsp.cause_by == any_to_str(WriteCode)
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
assert src_file_repo.changed_files
def test_parse_str():
for idx, i in enumerate(STRS_FOR_PARSING):
text = CodeParser.parse_str(f"{idx+1}", i)
text = CodeParser.parse_str(f"{idx + 1}", i)
# logger.info(text)
assert text == "a"
@ -62,14 +79,11 @@ target_list = [
def test_parse_file_list():
tasks = CodeParser.parse_file_list("任务列表", TASKS)
tasks = CodeParser.parse_file_list("Task list", TASKS)
logger.info(tasks)
assert isinstance(tasks, list)
assert target_list == tasks
file_list = CodeParser.parse_file_list("Task list", TASKS_TOMATO_CLOCK, lang="python")
logger.info(file_list)
target_code = """task_list = [
"smart_search_engine/knowledge_base.py",
@ -88,7 +102,64 @@ target_code = """task_list = [
def test_parse_code():
code = CodeParser.parse_code("任务列表", TASKS, lang="python")
code = CodeParser.parse_code("Task list", TASKS, lang="python")
logger.info(code)
assert isinstance(code, str)
assert target_code == code
def test_todo():
role = Engineer()
assert role.todo == any_to_name(WriteCode)
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_new_coding_context():
# Prerequisites
demo_path = Path(__file__).parent / "../../data/demo_project"
deps = json.loads(await aread(demo_path / "dependencies.json"))
dependency = await CONFIG.git_repo.get_dependency()
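# Seed the repo's dependency graph from the demo project's dependencies.json.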
for k, v in deps.items():
await dependency.update(k, set(v))
data = await aread(demo_path / "system_design.json")
rqno = "20231221155954.json"
await awrite(CONFIG.git_repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data)
data = await aread(demo_path / "tasks.json")
await awrite(CONFIG.git_repo.workdir / TASK_FILE_REPO / rqno, data)
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "game_2048"
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
task_file_repo = CONFIG.git_repo.new_file_repository(relative_path=TASK_FILE_REPO)
design_file_repo = CONFIG.git_repo.new_file_repository(relative_path=SYSTEM_DESIGN_FILE_REPO)
filename = "game.py"
ctx_doc = await Engineer._new_coding_doc(
filename=filename,
src_file_repo=src_file_repo,
task_file_repo=task_file_repo,
design_file_repo=design_file_repo,
dependency=dependency,
)
assert ctx_doc
assert ctx_doc.filename == filename
assert ctx_doc.content
ctx = CodingContext.model_validate_json(ctx_doc.content)
assert ctx.filename == filename
assert ctx.design_doc
assert ctx.design_doc.content
assert ctx.task_doc
assert ctx.task_doc.content
assert ctx.code_doc
CONFIG.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED})
CONFIG.git_repo.commit("mock env")
await src_file_repo.save(filename=filename, content="content")
role = Engineer()
assert not role.code_todos
await role._new_code_actions()
assert role.code_todos
if __name__ == "__main__":
pytest.main([__file__, "-s"])


@ -7,12 +7,12 @@
@File : test_invoice_ocr_assistant.py
"""
import json
from pathlib import Path
import pandas as pd
import pytest
from metagpt.const import DATA_PATH, TEST_DATA_PATH
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
from metagpt.schema import Message
@ -23,41 +23,36 @@ from metagpt.schema import Message
[
(
"Invoicing date",
Path("../../data/invoices/invoice-1.pdf"),
Path("../../../data/invoice_table/invoice-1.xlsx"),
[{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
Path("invoices/invoice-1.pdf"),
Path("invoice_table/invoice-1.xlsx"),
{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-2.png"),
Path("../../../data/invoice_table/invoice-2.xlsx"),
[{"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
Path("invoices/invoice-2.png"),
Path("invoice_table/invoice-2.xlsx"),
{"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-3.jpg"),
Path("../../../data/invoice_table/invoice-3.xlsx"),
[{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
),
(
"Invoicing date",
Path("../../data/invoices/invoice-4.zip"),
Path("../../../data/invoice_table/invoice-4.xlsx"),
[
{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
],
Path("invoices/invoice-3.jpg"),
Path("invoice_table/invoice-3.xlsx"),
{"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
),
],
)
async def test_invoice_ocr_assistant(
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
):
invoice_path = Path.cwd() / invoice_path
@pytest.mark.usefixtures("llm_mock")
async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict):
invoice_path = TEST_DATA_PATH / invoice_path
role = InvoiceOCRAssistant()
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
invoice_table_path = Path.cwd() / invoice_table_path
invoice_table_path = DATA_PATH / invoice_table_path
df = pd.read_excel(invoice_table_path)
dict_result = df.to_dict(orient="records")
assert json.dumps(dict_result) == json.dumps(expected_result)
resp = df.to_dict(orient="records")
assert isinstance(resp, list)
assert len(resp) == 1
resp = resp[0]
assert expected_result["收款人"] == resp["收款人"]
assert expected_result["城市"] in resp["城市"]
assert float(expected_result["总费用/元"]) == float(resp["总费用/元"])
assert expected_result["开票日期"] == resp["开票日期"]


@ -13,9 +13,10 @@ from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_product_manager():
product_manager = ProductManager()
rsp = await product_manager.handle(MockMessages.req)
rsp = await product_manager.run(MockMessages.req)
logger.info(rsp)
assert len(rsp.content) > 0
assert "Product Goals" in rsp.content
assert rsp.content == MockMessages.req.content

@@ -13,7 +13,8 @@ from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_project_manager():
project_manager = ProjectManager()
rsp = await project_manager.handle(MockMessages.system_design)
rsp = await project_manager.run(MockMessages.system_design)
logger.info(rsp)

@@ -5,3 +5,59 @@
@Author : alexanderwu
@File : test_qa_engineer.py
"""
from pathlib import Path
from typing import List
import pytest
from pydantic import Field
from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.environment import Environment
from metagpt.roles import QaEngineer
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, aread, awrite
async def test_qa():
# Prerequisites
demo_path = Path(__file__).parent / "../../data/demo_project"
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "qa/game_2048"
data = await aread(filename=demo_path / "game.py", encoding="utf-8")
await awrite(filename=CONFIG.src_workspace / "game.py", data=data, encoding="utf-8")
await awrite(filename=Path(CONFIG.git_repo.workdir) / "requirements.txt", data="")
class MockEnv(Environment):
msgs: List[Message] = Field(default_factory=list)
def publish_message(self, message: Message, peekable: bool = True) -> bool:
self.msgs.append(message)
return True
env = MockEnv()
role = QaEngineer()
role.set_env(env)
await role.run(with_message=Message(content="", cause_by=SummarizeCode))
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(WriteTest)
msg = env.msgs[0]
env.msgs.clear()
await role.run(with_message=msg)
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(RunCode)
msg = env.msgs[0]
env.msgs.clear()
await role.run(with_message=msg)
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(DebugError)
msg = env.msgs[0]
env.msgs.clear()
role.test_round_allowed = 1
rsp = await role.run(with_message=msg)
assert "Exceeding" in rsp.content
if __name__ == "__main__":
pytest.main([__file__, "-s"])
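
`aread`/`awrite` come from metagpt.utils.common; a rough sketch of such helpers, assuming an aiofiles-based implementation with directory auto-creation on write (the real signatures may differ):

from pathlib import Path

import aiofiles

async def aread(filename, encoding="utf-8") -> str:
    # read a whole text file without blocking the event loop
    async with aiofiles.open(str(filename), mode="r", encoding=encoding) as f:
        return await f.read()

async def awrite(filename, data: str, encoding="utf-8"):
    # create missing parent directories, then write the text
    path = Path(filename)
    path.parent.mkdir(parents=True, exist_ok=True)
    async with aiofiles.open(str(path), mode="w", encoding=encoding) as f:
        await f.write(data)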

@@ -28,7 +28,7 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
async def test_researcher(mocker):
with TemporaryDirectory() as dirname:
topic = "dataiku vs. datarobot"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
researcher.RESEARCH_PATH = Path(dirname)
await researcher.Researcher().run(topic)
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
@@ -48,3 +48,7 @@ def test_write_report(mocker):
content = "# Research Report"
researcher.Researcher().write_report(topic, content)
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")
if __name__ == "__main__":
pytest.main([__file__, "-s"])

@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of Role
import pytest
from metagpt.roles.role import Role
@@ -8,4 +9,8 @@ from metagpt.roles.role import Role
def test_role_desc():
role = Role(profile="Sales", desc="Best Seller")
assert role.profile == "Sales"
assert role._setting.desc == "Best Seller"
assert role.desc == "Best Seller"
if __name__ == "__main__":
pytest.main([__file__, "-s"])

@@ -0,0 +1,158 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/27 13:25
@Author : mashenquan
@File : test_teacher.py
"""
import os
from typing import Dict, Optional
import pytest
from pydantic import BaseModel
from metagpt.config import CONFIG, Config
from metagpt.roles.teacher import Teacher
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_init():
class Inputs(BaseModel):
name: str
profile: str
goal: str
constraints: str
desc: str
kwargs: Optional[Dict] = None
expect_name: str
expect_profile: str
expect_goal: str
expect_constraints: str
expect_desc: str
inputs = [
{
"name": "Lily{language}",
"expect_name": "Lily{language}",
"profile": "X {teaching_language}",
"expect_profile": "X {teaching_language}",
"goal": "Do {something_big}, {language}",
"expect_goal": "Do {something_big}, {language}",
"constraints": "Do in {key1}, {language}",
"expect_constraints": "Do in {key1}, {language}",
"kwargs": {},
"desc": "aaa{language}",
"expect_desc": "aaa{language}",
},
{
"name": "Lily{language}",
"expect_name": "LilyCN",
"profile": "X {teaching_language}",
"expect_profile": "X EN",
"goal": "Do {something_big}, {language}",
"expect_goal": "Do sleep, CN",
"constraints": "Do in {key1}, {language}",
"expect_constraints": "Do in HaHa, CN",
"kwargs": {"language": "CN", "key1": "HaHa", "something_big": "sleep", "teaching_language": "EN"},
"desc": "aaa{language}",
"expect_desc": "aaaCN",
},
]
env = os.environ.copy()
for i in inputs:
seed = Inputs(**i)
os.environ.clear()
os.environ.update(env)
CONFIG = Config()
CONFIG.set_context(seed.kwargs)
print(CONFIG.options)
assert bool("language" in seed.kwargs) == bool("language" in CONFIG.options)
teacher = Teacher(
name=seed.name,
profile=seed.profile,
goal=seed.goal,
constraints=seed.constraints,
desc=seed.desc,
)
assert teacher.name == seed.expect_name
assert teacher.desc == seed.expect_desc
assert teacher.profile == seed.expect_profile
assert teacher.goal == seed.expect_goal
assert teacher.constraints == seed.expect_constraints
assert teacher.course_title == "teaching_plan"
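
The expectations above reduce to str.format-style substitution that leaves a template unchanged when any key is missing from the configured context. A minimal sketch of that rule (the `fill` helper is hypothetical, not MetaGPT API):

def fill(template: str, context: dict) -> str:
    # "Lily{language}" with {"language": "CN"} -> "LilyCN";
    # with an empty context the placeholder survives unchanged
    try:
        return template.format(**context)
    except (KeyError, IndexError):
        return template

assert fill("Lily{language}", {}) == "Lily{language}"
assert fill("Lily{language}", {"language": "CN"}) == "LilyCN"
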
@pytest.mark.asyncio
async def test_new_file_name():
class Inputs(BaseModel):
lesson_title: str
ext: str
expect: str
inputs = [
{"lesson_title": "# @344\n12", "ext": ".md", "expect": "_344_12.md"},
{"lesson_title": "1#@$%!*&\\/:*?\"<>|\n\t '1", "ext": ".cc", "expect": "1_1.cc"},
]
for i in inputs:
seed = Inputs(**i)
result = Teacher.new_file_name(seed.lesson_title, seed.ext)
assert result == seed.expect
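
Both fixtures are consistent with one rule: collapse each run of characters outside [A-Za-z0-9] into a single underscore. A sketch of a sanitizer that satisfies the cases above (an assumption about the rule, not Teacher.new_file_name itself):

import re

def sanitize_file_name(lesson_title: str, ext: str) -> str:
    # collapse each run of non-alphanumeric characters into a single "_"
    return re.sub(r"[^A-Za-z0-9]+", "_", lesson_title) + ext

assert sanitize_file_name("# @344\n12", ".md") == "_344_12.md"
assert sanitize_file_name("1#@$%!*&\\/:*?\"<>|\n\t '1", ".cc") == "1_1.cc"
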
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_run():
CONFIG.set_context({"language": "Chinese", "teaching_language": "English"})
lesson = """
UNIT 1 Making New Friends
TOPIC 1 Welcome to China!
Section A
1a Listen and number the following names.
Jane Mari Kangkang Michael
Look, listen and understand. Then practice the conversation.
Work in groups. Introduce yourself using
I'm ... Then practice 1a
with your own hometown or the following places.
1b Listen and number the following names
Jane Michael Maria Kangkang
1c Work in groups. Introduce yourself using I'm ... Then practice 1a with your own hometown or the following places.
China the USA the UK Hong Kong Beijing
2a Look, listen and understand. Then practice the conversation
Hello!
Hello!
Hello!
Hello! Are you Maria?
No, I'm not. I'm Jane.
Oh, nice to meet you, Jane
Nice to meet you, too.
Hi, Maria!
Hi, Kangkang!
Welcome to China!
Thanks.
2b Work in groups. Make up a conversation with your own name and the
following structures.
A: Hello! / Good morning! / Hi! I'm ... Are you ... ?
B: ...
3a Listen, say and trace
Aa Bb Cc Dd Ee Ff Gg
3b Listen and number the following letters. Then circle the letters with the same sound as Bb.
Aa Bb Cc Dd Ee Ff Gg
3c Match the big letters with the small ones. Then write them on the lines.
"""
teacher = Teacher()
rsp = await teacher.run(Message(content=lesson))
assert rsp
if __name__ == "__main__":
pytest.main([__file__, "-s"])

@@ -5,20 +5,26 @@
@Author : Stitch-z
@File : test_tutorial_assistant.py
"""
import aiofiles
import pytest
from metagpt.const import TUTORIAL_PATH
from metagpt.roles.tutorial_assistant import TutorialAssistant
@pytest.mark.asyncio
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about Python")])
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")])
@pytest.mark.usefixtures("llm_mock")
async def test_tutorial_assistant(language: str, topic: str):
topic = "Write a tutorial about MySQL"
role = TutorialAssistant(language=language)
msg = await role.run(topic)
assert TUTORIAL_PATH.exists()
filename = msg.content
title = filename.split("/")[-1].split(".")[0]
async with aiofiles.open(filename, mode="r") as reader:
async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
content = await reader.read()
assert content.startswith(f"# {title}")
assert "pip" in content
if __name__ == "__main__":
pytest.main([__file__, "-s"])

@@ -1,21 +0,0 @@
# -*- coding: utf-8 -*-
# @Date : 2023/7/22 02:40
# @Author : stellahong (stellahong@deepwisdom.ai)
#
from metagpt.roles import ProductManager
from metagpt.team import Team
from tests.metagpt.roles.ui_role import UI
def test_add_ui():
ui = UI()
assert ui.profile == "UI Design"
async def test_ui_role(idea: str, investment: float = 3.0, n_round: int = 5):
"""Run a startup. Be a boss."""
company = Team()
company.hire([ProductManager(), UI()])
company.invest(investment)
company.run_project(idea)
await company.run(n_round=n_round)

@@ -1,282 +0,0 @@
# -*- coding: utf-8 -*-
# @Date : 2023/7/15 16:40
# @Author : stellahong (stellahong@deepwisdom.ai)
# @Desc :
import os
import re
from functools import wraps
from importlib import import_module
from metagpt.actions import Action, ActionOutput, WritePRD
# from metagpt.const import WORKSPACE_ROOT
from metagpt.actions.action_node import ActionNode
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.tools.sd_engine import SDEngine
PROMPT_TEMPLATE = """
{context}
## Role
You are a User Interface Designer; the goal is to finish a UI design according to the PRD, give a design description, and select specified elements and UI style.
"""
UI_DESIGN_DESC = ActionNode(
key="UI Design Desc",
expected_type=str,
instruction="place the design objective here",
example="Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements"
" commonly found in snake games",
)
SELECTED_ELEMENTS = ActionNode(
key="Selected Elements",
expected_type=list[str],
instruction="up to 5 specified elements, clear and simple",
example=[
"Game Grid: The game grid is a rectangular...",
"Snake: The player controls a snake that moves across the grid...",
"Food: Food items (often represented as small objects or differently colored blocks)",
"Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score.",
"Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game.",
],
)
HTML_LAYOUT = ActionNode(
key="HTML Layout",
expected_type=str,
instruction="use standard HTML code",
example="""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Snake Game</title>
<link rel="stylesheet" href="styles.css">
</head>
<body>
<div class="game-grid">
<!-- Snake will be dynamically generated here using JavaScript -->
</div>
<div class="food">
<!-- Food will be dynamically generated here using JavaScript -->
</div>
</body>
</html>
""",
)
CSS_STYLES = ActionNode(
key="CSS Styles",
expected_type=str,
instruction="use standard css code",
example="""body {
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background-color: #f0f0f0;
}
.game-grid {
width: 400px;
height: 400px;
display: grid;
grid-template-columns: repeat(20, 1fr); /* Adjust to the desired grid size */
grid-template-rows: repeat(20, 1fr);
gap: 1px;
background-color: #222;
border: 1px solid #555;
}
.game-grid div {
width: 100%;
height: 100%;
background-color: #444;
}
.snake-segment {
background-color: #00cc66; /* Snake color */
}
.food {
width: 100%;
height: 100%;
background-color: #cc3300; /* Food color */
position: absolute;
}
/* Optional styles for a simple game over message */
.game-over {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
font-size: 24px;
font-weight: bold;
color: #ff0000;
display: none;
}
""",
)
ANYTHING_UNCLEAR = ActionNode(
key="Anything UNCLEAR",
expected_type=str,
instruction="Mention any aspects of the project that are unclear and try to clarify them.",
example="...",
)
NODES = [
UI_DESIGN_DESC,
SELECTED_ELEMENTS,
HTML_LAYOUT,
CSS_STYLES,
ANYTHING_UNCLEAR,
]
UI_DESIGN_NODE = ActionNode.from_children("UI_DESIGN", NODES)
def load_engine(func):
"""Decorator to load an engine by file name and engine name."""
@wraps(func)
def wrapper(*args, **kwargs):
file_name, engine_name = func(*args, **kwargs)
engine_file = import_module(file_name, package="metagpt")
ip_module_cls = getattr(engine_file, engine_name)
try:
engine = ip_module_cls()
except Exception:  # fall back to None when the engine cannot be instantiated
engine = None
return engine
return wrapper
def parse(func):
"""Decorator to parse information using regex pattern."""
@wraps(func)
def wrapper(*args, **kwargs):
context, pattern = func(*args, **kwargs)
match = re.search(pattern, context, re.DOTALL)
if match:
text_info = match.group(1)
logger.info(text_info)
else:
text_info = context
logger.info("未找到匹配的内容")
return text_info
return wrapper
class UIDesign(Action):
"""Class representing the UI Design action."""
def __init__(self, name, context=None, llm=None):
super().__init__(name, context, llm)  # the LLM needs to be called to further enrich the UI design prompt
@parse
def parse_requirement(self, context: str):
"""Parse UI Design draft from the context using regex."""
pattern = r"## UI Design draft.*?\n(.*?)## Anything UNCLEAR"
return context, pattern
@parse
def parse_ui_elements(self, context: str):
"""Parse Selected Elements from the context using regex."""
pattern = r"## Selected Elements.*?\n(.*?)## HTML Layout"
return context, pattern
@parse
def parse_css_code(self, context: str):
pattern = r"```css.*?\n(.*?)## Anything UNCLEAR"
return context, pattern
@parse
def parse_html_code(self, context: str):
pattern = r"```html.*?\n(.*?)```"
return context, pattern
async def draw_icons(self, context, *args, **kwargs):
"""Draw icons using SDEngine."""
engine = SDEngine()
icon_prompts = self.parse_ui_elements(context)
icons = icon_prompts.split("\n")
icons = [s for s in icons if len(s.strip()) > 0]
prompts_batch = []
for icon_prompt in icons:
# fixme: add the icon LoRA
prompt = engine.construct_payload(icon_prompt + ".<lora:WZ0710_AW81e-3_30e3b128d64T32_goon0.5>")
prompts_batch.append(prompt)
await engine.run_t2i(prompts_batch)
logger.info("Finish icon design using StableDiffusion API")
async def _save(self, css_content, html_content):
save_dir = CONFIG.workspace_path / "resources" / "codes"
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
# Save CSS and HTML content to files
css_file_path = save_dir / "ui_design.css"
html_file_path = save_dir / "ui_design.html"
css_file_path.write_text(css_content)
html_file_path.write_text(html_content)
async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput:
"""Run the UI Design action."""
# fixme: update prompt (refine the prompt according to the requirements)
context = requirements[-1].content
ui_design_draft = self.parse_requirement(context=context)
# todo: parse requirements str
prompt = PROMPT_TEMPLATE.format(context=ui_design_draft)
logger.info(prompt)
ui_describe = await UI_DESIGN_NODE.fill(prompt)
logger.info(ui_describe.content)
logger.info(ui_describe.instruct_content)
css = self.parse_css_code(context=ui_describe.content)
html = self.parse_html_code(context=ui_describe.content)
await self._save(css_content=css, html_content=html)
await self.draw_icons(ui_describe.content)
return ui_describe
class UI(Role):
"""Class representing the UI Role."""
def __init__(
self,
name="Catherine",
profile="UI Design",
goal="Finish a workable and good User Interface design based on a product design",
constraints="Give clear layout description and use standard icons to finish the design",
skills=["SD"],
):
super().__init__(name, profile, goal, constraints)
self.load_skills(skills)
self._init_actions([UIDesign])
self._watch([WritePRD])
@load_engine
def load_sd_engine(self):
"""Load the SDEngine."""
file_name = ".tools.sd_engine"
engine_name = "SDEngine"
return file_name, engine_name
def load_skills(self, skills):
"""Load skills for the UI Role."""
# todo: add other image-generation engines
for skill in skills:
if skill == "SD":
self.sd_engine = self.load_sd_engine()
logger.info(f"load skill engine {self.sd_engine}")

@@ -10,18 +10,24 @@ from metagpt.llm import LLM
def test_action_serialize():
action = Action()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" not in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
assert "__module_class_name" not in ser_action_dict
action = Action(name="test")
ser_action_dict = action.model_dump()
assert "test" in ser_action_dict["name"]
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_action_deserialize():
action = Action()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = Action(**serialized_data)
assert new_action.name == ""
assert new_action.llm == LLM()
assert new_action.name == "Action"
assert isinstance(new_action.llm, type(LLM()))
assert len(await new_action._aask("who are you")) > 0
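
The `"llm" not in ser_action_dict` assertions suggest the llm field is excluded at the field level; in pydantic v2 that is a one-liner. A sketch assuming `Field(exclude=True)` rather than MetaGPT's actual declaration:

from typing import Any

from pydantic import BaseModel, Field

class ActionLike(BaseModel):
    name: str = ""
    llm: Any = Field(default=None, exclude=True)  # never exported by model_dump

assert "llm" not in ActionLike().model_dump()
assert "name" in ActionLike().model_dump()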

@@ -10,19 +10,20 @@ from metagpt.roles.architect import Architect
def test_architect_serialize():
role = Architect()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_architect_deserialize():
role = Architect()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
new_role = Architect(**ser_role_dict)
# new_role = Architect.deserialize(ser_role_dict)
assert new_role.name == "Bob"
assert len(new_role._actions) == 1
assert isinstance(new_role._actions[0], Action)
await new_role._actions[0].run(with_messages="write a cli snake game")
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], Action)
await new_role.actions[0].run(with_messages="write a cli snake game")

@@ -13,6 +13,7 @@ from metagpt.schema import Message
from metagpt.utils.common import any_to_str
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOK,
ActionRaise,
RoleC,
serdeser_path,
)
@@ -20,14 +21,15 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
def test_env_serialize():
env = Environment()
ser_env_dict = env.dict()
ser_env_dict = env.model_dump()
assert "roles" in ser_env_dict
assert len(ser_env_dict["roles"]) == 0
def test_env_deserialize():
env = Environment()
env.publish_message(message=Message(content="test env serialize"))
ser_env_dict = env.dict()
ser_env_dict = env.model_dump()
new_env = Environment(**ser_env_dict)
assert len(new_env.roles) == 0
assert len(new_env.history) == 25
@@ -47,16 +49,16 @@ def test_environment_serdeser():
environment.add_role(role_c)
environment.publish_message(message)
ser_data = environment.dict()
ser_data = environment.model_dump()
assert ser_data["roles"]["Role C"]["name"] == "RoleC"
new_env: Environment = Environment(**ser_data)
assert len(new_env.roles) == 1
assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states
assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions
assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK)
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
assert list(new_env.roles.values())[0].states == list(environment.roles.values())[0].states
assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise
def test_environment_serdeser_v2():
@@ -64,13 +66,13 @@ def test_environment_serdeser_v2():
pm = ProjectManager()
environment.add_role(pm)
ser_data = environment.dict()
ser_data = environment.model_dump()
new_env: Environment = Environment(**ser_data)
role = new_env.get_role(pm.profile)
assert isinstance(role, ProjectManager)
assert isinstance(role._actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks)
assert isinstance(role.actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
def test_environment_serdeser_save():
@@ -85,4 +87,4 @@ def test_environment_serdeser_save():
new_env: Environment = Environment.deserialize(stg_path)
assert len(new_env.roles) == 1
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK

@@ -25,7 +25,7 @@ def test_memory_serdeser():
memory = Memory()
memory.add_batch([msg1, msg2])
ser_data = memory.dict()
ser_data = memory.model_dump()
new_memory = Memory(**ser_data)
assert new_memory.count() == 2
@@ -35,6 +35,9 @@ def test_memory_serdeser():
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
assert new_msg2.role == "Boss"
memory = Memory(storage=[msg1, msg2], index={msg1.cause_by: [msg1], msg2.cause_by: [msg2]})
assert memory.count() == 2
def test_memory_serdeser_save():
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)

@@ -0,0 +1,58 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of polymorphic conditions
from pydantic import BaseModel, ConfigDict, SerializeAsAny
from metagpt.actions import Action
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOKV2,
ActionPass,
)
class ActionSubClasses(BaseModel):
actions: list[SerializeAsAny[Action]] = []
class ActionSubClassesNoSAA(BaseModel):
"""without SerializeAsAny"""
model_config = ConfigDict(arbitrary_types_allowed=True)
actions: list[Action] = []
def test_serialize_as_any():
"""test subclasses of action with different fields in ser&deser"""
# ActionOKV2 with an extra field `extra_field`
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
assert action_subcls_dict["actions"][0]["extra_field"] == ActionOKV2().extra_field
def test_no_serialize_as_any():
# ActionOKV2 with an extra field `extra_field`
action_subcls = ActionSubClassesNoSAA(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
# without `SerializeAsAny`, it will serialize as Action
assert "extra_field" not in action_subcls_dict["actions"][0]
def test_polymorphic():
_ = ActionOKV2(
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
)
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
assert "__module_class_name" in action_subcls_dict["actions"][0]
new_action_subcls = ActionSubClasses(**action_subcls_dict)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert isinstance(new_action_subcls.actions[1], ActionPass)
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert isinstance(new_action_subcls.actions[1], ActionPass)
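
The `__module_class_name` key checked above is the tag that lets a dumped dict be restored to its concrete subclass. A hedged sketch of that round-trip idea using pydantic v2 wrap hooks (`Polymorphic` is illustrative; MetaGPT's actual serializer may differ):

from importlib import import_module

from pydantic import BaseModel, model_serializer, model_validator

class Polymorphic(BaseModel):
    name: str = ""

    @model_serializer(mode="wrap")
    def _dump(self, handler):
        # stamp the concrete class path next to the regular fields
        data = handler(self)
        data["__module_class_name"] = f"{type(self).__module__}.{type(self).__qualname__}"
        return data

    @model_validator(mode="wrap")
    @classmethod
    def _load(cls, data, handler):
        # when validating a plain dict, re-dispatch to the stamped subclass
        if isinstance(data, dict) and "__module_class_name" in data:
            data = dict(data)
            module, _, name = data.pop("__module_class_name").rpartition(".")
            target = getattr(import_module(module), name)
            if target is not cls:
                return target.model_validate(data)
        return handler(data)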

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.actions.action_node import ActionNode
from metagpt.actions.prepare_interview import PrepareInterview
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_action_deserialize():
action = PrepareInterview()
serialized_data = action.model_dump()
assert serialized_data["name"] == "PrepareInterview"
new_action = PrepareInterview(**serialized_data)
assert new_action.name == "PrepareInterview"
assert type(await new_action.run("python developer")) == ActionNode

@@ -10,12 +10,13 @@ from metagpt.schema import Message
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_product_manager_deserialize():
role = ProductManager()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
new_role = ProductManager(**ser_role_dict)
assert new_role.name == "Alice"
assert len(new_role._actions) == 2
assert isinstance(new_role._actions[0], Action)
await new_role._actions[0].run([Message(content="write a cli snake game")])
assert len(new_role.actions) == 2
assert isinstance(new_role.actions[0], Action)
await new_role.actions[0].run([Message(content="write a cli snake game")])

@@ -11,20 +11,21 @@ from metagpt.roles.project_manager import ProjectManager
def test_project_manager_serialize():
role = ProjectManager()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_project_manager_deserialize():
role = ProjectManager()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
new_role = ProjectManager(**ser_role_dict)
assert new_role.name == "Eve"
assert len(new_role._actions) == 1
assert isinstance(new_role._actions[0], Action)
assert isinstance(new_role._actions[0], WriteTasks)
# await new_role._actions[0].run(context="write a cli snake game")
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], Action)
assert isinstance(new_role.actions[0], WriteTasks)
# await new_role.actions[0].run(context="write a cli snake game")

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.actions import CollectLinks
from metagpt.roles.researcher import Researcher
@pytest.mark.asyncio
async def test_tutorial_assistant_deserialize():
role = Researcher()
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict
assert "language" in ser_role_dict
new_role = Researcher(**ser_role_dict)
assert new_role.language == "en-us"
assert len(new_role.actions) == 3
assert isinstance(new_role.actions[0], CollectLinks)
# todo: test whether memory is persisted correctly when different actions fail

@@ -6,6 +6,7 @@
import shutil
import pytest
from pydantic import BaseModel, SerializeAsAny
from metagpt.actions import WriteCode
from metagpt.actions.add_requirement import UserRequirement
@@ -17,48 +18,68 @@ from metagpt.schema import Message
from metagpt.schema import Message
from metagpt.utils.common import format_trackback_info
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOK,
RoleA,
RoleB,
RoleC,
RoleD,
serdeser_path,
)
def test_roles():
role_a = RoleA()
assert len(role_a._rc.watch) == 1
assert len(role_a.rc.watch) == 1
role_b = RoleB()
assert len(role_a._rc.watch) == 1
assert len(role_b._rc.watch) == 1
assert len(role_a.rc.watch) == 1
assert len(role_b.rc.watch) == 1
role_d = RoleD(actions=[ActionOK()])
assert len(role_d.actions) == 1
def test_role_subclasses():
"""test subclasses of role with same fields in ser&deser"""
class RoleSubClasses(BaseModel):
roles: list[SerializeAsAny[Role]] = []
role_subcls = RoleSubClasses(roles=[RoleA(), RoleB()])
role_subcls_dict = role_subcls.model_dump()
new_role_subcls = RoleSubClasses(**role_subcls_dict)
assert isinstance(new_role_subcls.roles[0], RoleA)
assert isinstance(new_role_subcls.roles[1], RoleB)
def test_role_serialize():
role = Role()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
def test_engineer_serialize():
role = Engineer()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_engineer_deserialize():
role = Engineer(use_code_review=True)
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump()
new_role = Engineer(**ser_role_dict)
assert new_role.name == "Alex"
assert new_role.use_code_review is True
assert len(new_role._actions) == 1
assert isinstance(new_role._actions[0], WriteCode)
# await new_role._actions[0].run(context="write a cli snake game", filename="test_code")
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], WriteCode)
# await new_role.actions[0].run(context="write a cli snake game", filename="test_code")
def test_role_serdeser_save():
@@ -76,6 +97,7 @@ def test_role_serdeser_save():
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_role_serdeser_interrupt():
role_c = RoleC()
shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True)
@@ -87,10 +109,14 @@ async def test_role_serdeser_interrupt():
logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}")
role_c.serialize(stg_path)
assert role_c._rc.memory.count() == 1
assert role_c.rc.memory.count() == 1
new_role_a: Role = Role.deserialize(stg_path)
assert new_role_a._rc.state == 1
assert new_role_a.rc.state == 1
with pytest.raises(Exception):
await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement))
if __name__ == "__main__":
pytest.main([__file__, "-s"])

@@ -4,9 +4,12 @@
from metagpt.actions.action_node import ActionNode
from metagpt.actions.write_code import WriteCode
from metagpt.schema import Message
from metagpt.schema import Document, Documents, Message
from metagpt.utils.common import any_to_str
from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
MockICMessage,
MockMessage,
)
def test_message_serdeser():
@@ -15,14 +18,24 @@ def test_message_serdeser():
ic_obj = ActionNode.create_model_class("code", out_mapping)
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
ser_data = message.dict()
ser_data = message.model_dump()
assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode"
assert ser_data["instruct_content"]["class"] == "code"
new_message = Message(**ser_data)
assert new_message.cause_by == any_to_str(WriteCode)
assert new_message.cause_by in [any_to_str(WriteCode)]
assert new_message.instruct_content == ic_obj(**out_data)
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
message = Message(content="test_ic", instruct_content=MockICMessage())
ser_data = message.model_dump()
new_message = Message(**ser_data)
assert new_message.instruct_content != MockICMessage() # TODO
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
ser_data = message.model_dump()
assert "class" in ser_data["instruct_content"]
def test_message_without_postprocess():
@@ -31,8 +44,9 @@ def test_message_without_postprocess():
out_data = {"field1": ["field1 value1", "field1 value2"]}
ic_obj = ActionNode.create_model_class("code", out_mapping)
message = MockMessage(content="code", instruct_content=ic_obj(**out_data))
ser_data = message.dict()
assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]}
ser_data = message.model_dump()
assert ser_data["instruct_content"] == {}
ser_data["instruct_content"] = None
new_message = MockMessage(**ser_data)
assert new_message.instruct_content != ic_obj(**out_data)
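
The `!=` flagged with a TODO above is expected pydantic v2 behavior: model equality requires the two instances to share a class, and the deserialized instruct_content presumably comes from a freshly created model class, so field-identical instances still compare unequal. A minimal reproduction with `create_model`:

from pydantic import create_model

A1 = create_model("code", field1=(list[str], ...))
A2 = create_model("code", field1=(list[str], ...))
a, b = A1(field1=["x"]), A2(field1=["x"])
assert a != b                            # distinct classes never compare equal
assert a.model_dump() == b.model_dump()  # even though the data matches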

@@ -4,6 +4,7 @@
import asyncio
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
@@ -15,15 +16,19 @@ from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
class MockICMessage(BaseModel):
content: str = "test_ic"
class MockMessage(BaseModel):
"""to test normal dict without postprocess"""
content: str = ""
instruct_content: BaseModel = Field(default=None)
instruct_content: Optional[BaseModel] = Field(default=None)
class ActionPass(Action):
name: str = Field(default="ActionPass")
name: str = "ActionPass"
async def run(self, messages: list["Message"]) -> ActionOutput:
await asyncio.sleep(5)  # sleep so that other roles can watch the executed Message
@@ -35,7 +40,7 @@ class ActionPass(Action):
class ActionOK(Action):
name: str = Field(default="ActionOK")
name: str = "ActionOK"
async def run(self, messages: list["Message"]) -> str:
await asyncio.sleep(5)
@@ -43,12 +48,17 @@ class ActionOK(Action):
class ActionRaise(Action):
name: str = Field(default="ActionRaise")
name: str = "ActionRaise"
async def run(self, messages: list["Message"]) -> str:
raise RuntimeError("parse error in ActionRaise")
class ActionOKV2(Action):
name: str = "ActionOKV2"
extra_field: str = "ActionOKV2 Extra Info"
class RoleA(Role):
name: str = Field(default="RoleA")
profile: str = Field(default="Role A")
@@ -71,7 +81,7 @@ class RoleB(Role):
super(RoleB, self).__init__(**kwargs)
self._init_actions([ActionOK, ActionRaise])
self._watch([ActionPass])
self._rc.react_mode = RoleReactMode.BY_ORDER
self.rc.react_mode = RoleReactMode.BY_ORDER
class RoleC(Role):
@@ -84,4 +94,12 @@ class RoleC(Role):
super(RoleC, self).__init__(**kwargs)
self._init_actions([ActionOK, ActionRaise])
self._watch([UserRequirement])
self._rc.react_mode = RoleReactMode.BY_ORDER
self.rc.react_mode = RoleReactMode.BY_ORDER
self.rc.memory.ignore_id = True
class RoleD(Role):
name: str = Field(default="RoleD")
profile: str = Field(default="Role D")
goal: str = "RoleD's goal"
constraints: str = "RoleD's constraints"

@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.roles.sk_agent import SkAgent
def test_sk_agent_serialize():
role = SkAgent()
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
assert "name" in ser_role_dict
assert "planner" in ser_role_dict
@pytest.mark.asyncio
async def test_sk_agent_deserialize():
role = SkAgent()
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
assert "name" in ser_role_dict
assert "planner" in ser_role_dict
new_role = SkAgent(**ser_role_dict)
assert new_role.name == "Sunshine"
assert len(new_role.actions) == 1

@@ -33,8 +33,8 @@ def test_team_deserialize():
]
)
assert len(company.env.get_roles()) == 3
ser_company = company.dict()
new_company = Team(**ser_company)
ser_company = company.model_dump()
new_company = Team.model_validate(ser_company)
assert len(new_company.env.get_roles()) == 3
assert new_company.env.get_role(pm.profile) is not None
@@ -47,6 +47,7 @@ def test_team_deserialize():
def test_team_serdeser_save():
company = Team()
company.hire([RoleC()])
stg_path = serdeser_path.joinpath("team")
@@ -71,13 +72,13 @@ async def test_team_recover():
company.run_project(idea)
await company.run(n_round=4)
ser_data = company.dict()
ser_data = company.model_dump()
new_company = Team(**ser_data)
new_role_c = new_company.env.get_role(role_c.profile)
# assert new_role_c._rc.memory == role_c._rc.memory # TODO
assert new_role_c._rc.env != role_c._rc.env # TODO
assert type(list(new_company.env.roles.values())[0]._actions[0]) == ActionOK
new_company.env.get_role(role_c.profile)
# assert new_role_c.rc.memory == role_c.rc.memory # TODO
# assert new_role_c.rc.env != role_c.rc.env # TODO
assert type(list(new_company.env.roles.values())[0].actions[0]) == ActionOK
new_company.run_project(idea)
await new_company.run(n_round=4)
@@ -97,17 +98,18 @@ async def test_team_recover_save():
new_company = Team.deserialize(stg_path)
new_role_c = new_company.env.get_role(role_c.profile)
# assert new_role_c._rc.memory == role_c._rc.memory
assert new_role_c._rc.env != role_c._rc.env
# assert new_role_c.rc.memory == role_c.rc.memory
# assert new_role_c.rc.env != role_c.rc.env
assert new_role_c.recovered != role_c.recovered  # deserialization marks the restored role as recovered, so the flags differ
assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo`
assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news`
assert new_role_c.rc.todo != role_c.rc.todo # serialize exclude `rc.todo`
assert new_role_c.rc.news != role_c.rc.news # serialize exclude `rc.news`
new_company.run_project(idea)
await new_company.run(n_round=4)
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_team_recover_multi_roles_save():
idea = "write a snake game"
stg_path = SERDESER_PATH.joinpath("team")
@@ -116,10 +118,6 @@ async def test_team_recover_multi_roles_save():
role_a = RoleA()
role_b = RoleB()
assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", "RoleA"}
assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", "RoleB"}
assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"}
company = Team()
company.hire([role_a, role_b])
company.run_project(idea)
@@ -130,6 +128,6 @@ async def test_team_recover_multi_roles_save():
new_company = Team.deserialize(stg_path)
new_company.run_project(idea)
assert new_company.env.get_role(role_b.profile)._rc.state == 1
assert new_company.env.get_role(role_b.profile).rc.state == 1
await new_company.run(n_round=4)

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.actions.write_tutorial import WriteDirectory
from metagpt.roles.tutorial_assistant import TutorialAssistant
@pytest.mark.asyncio
async def test_tutorial_assistant_deserialize():
role = TutorialAssistant()
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict
assert "language" in ser_role_dict
assert "topic" in ser_role_dict
new_role = TutorialAssistant(**ser_role_dict)
assert new_role.name == "Stitch"
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], WriteDirectory)

@@ -6,27 +6,26 @@
import pytest
from metagpt.actions import WriteCode
from metagpt.llm import LLM
from metagpt.schema import CodingContext, Document
def test_write_design_serialize():
action = WriteCode()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert ser_action_dict["name"] == "WriteCode"
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_deserialize():
context = CodingContext(
filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers")
)
doc = Document(content=context.json())
doc = Document(content=context.model_dump_json())
action = WriteCode(context=doc)
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WriteCode(**serialized_data)
assert new_action.name == "WriteCode"
assert new_action.llm == LLM()
await action.run()

@@ -5,11 +5,11 @@
import pytest
from metagpt.actions import WriteCodeReview
from metagpt.llm import LLM
from metagpt.schema import CodingContext, Document
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_review_deserialize():
code_content = """
def div(a: int, b: int = 0):
@@ -22,11 +22,10 @@ def div(a: int, b: int = 0):
)
action = WriteCodeReview(context=context)
serialized_data = action.dict()
serialized_data = action.model_dump()
assert serialized_data["name"] == "WriteCodeReview"
new_action = WriteCodeReview(**serialized_data)
assert new_action.name == "WriteCodeReview"
assert new_action.llm == LLM()
await new_action.run()

@@ -5,38 +5,37 @@
import pytest
from metagpt.actions import WriteDesign, WriteTasks
from metagpt.llm import LLM
def test_write_design_serialize():
action = WriteDesign()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
def test_write_task_serialize():
action = WriteTasks()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_design_deserialize():
action = WriteDesign()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WriteDesign(**serialized_data)
assert new_action.name == ""
assert new_action.llm == LLM()
assert new_action.name == "WriteDesign"
await new_action.run(with_messages="write a cli snake game")
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_task_deserialize():
action = WriteTasks()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WriteTasks(**serialized_data)
assert new_action.name == "CreateTasks"
assert new_action.llm == LLM()
assert new_action.name == "WriteTasks"
await new_action.run(with_messages="write a cli snake game")
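
The expected default names changed from "" / "CreateTasks" to the class names ("WriteDesign", "WriteTasks"). One mechanism that produces exactly that (assumed here, not taken from the repo) is a post-validator that falls back to the class name:

from pydantic import BaseModel, model_validator

class ActionBase(BaseModel):
    name: str = ""

    @model_validator(mode="after")
    def _default_name(self):
        # fall back to the concrete class name when no name was given
        if not self.name:
            self.name = type(self).__name__
        return self

class WriteDesign(ActionBase):
    pass

assert WriteDesign().name == "WriteDesign"
assert ActionBase(name="custom").name == "custom"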

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.actions.write_docstring import WriteDocstring
code = """
def add_numbers(a: int, b: int):
return a + b
class Person:
def __init__(self, name: str, age: int):
self.name = name
self.age = age
def greet(self):
return f"Hello, my name is {self.name} and I am {self.age} years old."
"""
@pytest.mark.asyncio
@pytest.mark.parametrize(
("style", "part"),
[
("google", "Args:"),
("numpy", "Parameters"),
("sphinx", ":param name:"),
],
ids=["google", "numpy", "sphinx"],
)
@pytest.mark.usefixtures("llm_mock")
async def test_action_deserialize(style: str, part: str):
action = WriteDocstring()
serialized_data = action.model_dump()
assert "name" in serialized_data
assert serialized_data["desc"] == "Write docstring for code."
new_action = WriteDocstring(**serialized_data)
assert new_action.name == "WriteDocstring"
assert new_action.desc == "Write docstring for code."
ret = await new_action.run(code, style=style)
assert part in ret
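
For reference, the three parametrized markers correspond to the signature section syntax of each docstring convention (illustrative functions, not test fixtures):

def add_google(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: First addend.
        b: Second addend.
    """
    return a + b

def add_numpy(a: int, b: int) -> int:
    """Add two numbers.

    Parameters
    ----------
    a : int
        First addend.
    b : int
        Second addend.
    """
    return a + b

def add_sphinx(a: int, b: int) -> int:
    """Add two numbers.

    :param a: First addend.
    :param b: Second addend.
    """
    return a + b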

Some files were not shown because too many files have changed in this diff.