mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
Merge branch 'main' into dev_updated
This commit is contained in:
commit
853086924a
429 changed files with 24237 additions and 5835 deletions
143
tests/metagpt/actions/mock_json.py
Normal file
143
tests/metagpt/actions/mock_json.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/24 20:32
|
||||
@Author : alexanderwu
|
||||
@File : mock_json.py
|
||||
"""
|
||||
|
||||
PRD = {
|
||||
"Language": "zh_cn",
|
||||
"Programming Language": "Python",
|
||||
"Original Requirements": "写一个简单的cli贪吃蛇",
|
||||
"Project Name": "cli_snake",
|
||||
"Product Goals": ["创建一个简单易用的贪吃蛇游戏", "提供良好的用户体验", "支持不同难度级别"],
|
||||
"User Stories": [
|
||||
"作为玩家,我希望能够选择不同的难度级别",
|
||||
"作为玩家,我希望在每局游戏结束后能够看到我的得分",
|
||||
"作为玩家,我希望在输掉游戏后能够重新开始",
|
||||
"作为玩家,我希望看到简洁美观的界面",
|
||||
"作为玩家,我希望能够在手机上玩游戏",
|
||||
],
|
||||
"Competitive Analysis": ["贪吃蛇游戏A:界面简单,缺乏响应式特性", "贪吃蛇游戏B:美观且响应式的界面,显示最高得分", "贪吃蛇游戏C:响应式界面,显示最高得分,但有很多广告"],
|
||||
"Competitive Quadrant Chart": 'quadrantChart\n title "Reach and engagement of campaigns"\n x-axis "Low Reach" --> "High Reach"\n y-axis "Low Engagement" --> "High Engagement"\n quadrant-1 "We should expand"\n quadrant-2 "Need to promote"\n quadrant-3 "Re-evaluate"\n quadrant-4 "May be improved"\n "Game A": [0.3, 0.6]\n "Game B": [0.45, 0.23]\n "Game C": [0.57, 0.69]\n "Game D": [0.78, 0.34]\n "Game E": [0.40, 0.34]\n "Game F": [0.35, 0.78]\n "Our Target Product": [0.5, 0.6]',
|
||||
"Requirement Analysis": "",
|
||||
"Requirement Pool": [["P0", "主要代码..."], ["P0", "游戏算法..."]],
|
||||
"UI Design draft": "基本功能描述,简单的风格和布局。",
|
||||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
|
||||
DESIGN = {
|
||||
"Implementation approach": "我们将使用Python编程语言,并选择合适的开源框架来实现贪吃蛇游戏。我们将分析需求中的难点,并选择合适的开源框架来简化开发流程。",
|
||||
"File list": ["main.py", "game.py"],
|
||||
"Data structures and interfaces": "\nclassDiagram\n class Game {\n -int width\n -int height\n -int score\n -int speed\n -List<Point> snake\n -Point food\n +__init__(width: int, height: int, speed: int)\n +start_game()\n +change_direction(direction: str)\n +game_over()\n +update_snake()\n +update_food()\n +check_collision()\n }\n class Point {\n -int x\n -int y\n +__init__(x: int, y: int)\n }\n Game --> Point\n",
|
||||
"Program call flow": "\nsequenceDiagram\n participant M as Main\n participant G as Game\n M->>G: start_game()\n M->>G: change_direction(direction)\n G->>G: update_snake()\n G->>G: update_food()\n G->>G: check_collision()\n G-->>G: game_over()\n",
|
||||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
|
||||
TASKS = {
|
||||
"Required Python packages": ["pygame==2.0.1"],
|
||||
"Required Other language third-party packages": ["No third-party dependencies required"],
|
||||
"Logic Analysis": [
|
||||
["game.py", "Contains Game class and related functions for game logic"],
|
||||
["main.py", "Contains the main function, imports Game class from game.py"],
|
||||
],
|
||||
"Task list": ["game.py", "main.py"],
|
||||
"Full API spec": "",
|
||||
"Shared Knowledge": "'game.py' contains functions shared across the project.",
|
||||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
|
||||
FILE_GAME = """## game.py
|
||||
|
||||
import pygame
|
||||
import random
|
||||
|
||||
class Point:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
class Game:
|
||||
def __init__(self, width: int, height: int, speed: int):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.score = 0
|
||||
self.speed = speed
|
||||
self.snake = [Point(width // 2, height // 2)]
|
||||
self.food = self._create_food()
|
||||
|
||||
def start_game(self):
|
||||
pygame.init()
|
||||
self._display = pygame.display.set_mode((self.width, self.height))
|
||||
pygame.display.set_caption('Snake Game')
|
||||
self._clock = pygame.time.Clock()
|
||||
self._running = True
|
||||
|
||||
while self._running:
|
||||
self._handle_events()
|
||||
self._update_snake()
|
||||
self._update_food()
|
||||
self._check_collision()
|
||||
self._draw_screen()
|
||||
self._clock.tick(self.speed)
|
||||
|
||||
def change_direction(self, direction: str):
|
||||
# Update the direction of the snake based on user input
|
||||
pass
|
||||
|
||||
def game_over(self):
|
||||
# Display game over message and handle game over logic
|
||||
pass
|
||||
|
||||
def _create_food(self) -> Point:
|
||||
# Create and return a new food Point
|
||||
return Point(random.randint(0, self.width - 1), random.randint(0, self.height - 1))
|
||||
|
||||
def _handle_events(self):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
self._running = False
|
||||
|
||||
def _update_snake(self):
|
||||
# Update the position of the snake based on its direction
|
||||
pass
|
||||
|
||||
def _update_food(self):
|
||||
# Update the position of the food if the snake eats it
|
||||
pass
|
||||
|
||||
def _check_collision(self):
|
||||
# Check for collision between the snake and the walls or itself
|
||||
pass
|
||||
|
||||
def _draw_screen(self):
|
||||
self._display.fill((0, 0, 0)) # Clear the screen
|
||||
# Draw the snake and food on the screen
|
||||
pygame.display.update()
|
||||
|
||||
if __name__ == "__main__":
|
||||
game = Game(800, 600, 15)
|
||||
game.start_game()
|
||||
"""
|
||||
|
||||
FILE_GAME_CR_1 = """## Code Review: game.py
|
||||
1. Yes, the code is implemented as per the requirements. It initializes the game with the specified width, height, and speed, and starts the game loop.
|
||||
2. No, the logic for handling events and updating the snake, food, and collision is not implemented. To correct this, we need to implement the logic for handling events, updating the snake and food positions, and checking for collisions.
|
||||
3. Yes, the existing code follows the "Data structures and interfaces" by defining the Game and Point classes with the specified attributes and methods.
|
||||
4. No, several functions such as change_direction, game_over, _update_snake, _update_food, and _check_collision are not implemented. These functions need to be implemented to complete the game logic.
|
||||
5. Yes, all necessary pre-dependencies have been imported. The required pygame package is imported at the beginning of the file.
|
||||
6. No, methods from other files are not being reused as there are no other files being imported or referenced in the current code.
|
||||
|
||||
## Actions
|
||||
1. Implement the logic for handling events, updating the snake and food positions, and checking for collisions within the Game class.
|
||||
2. Implement the change_direction and game_over methods to handle user input and game over logic.
|
||||
3. Implement the _update_snake method to update the position of the snake based on its direction.
|
||||
4. Implement the _update_food method to update the position of the food if the snake eats it.
|
||||
5. Implement the _check_collision method to check for collision between the snake and the walls or itself.
|
||||
|
||||
## Code Review Result
|
||||
LBTM"""
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
@Time : 2023/5/18 23:51
|
||||
@Author : alexanderwu
|
||||
@File : mock.py
|
||||
@File : mock_markdown.py
|
||||
"""
|
||||
|
||||
PRD_SAMPLE = """## Original Requirements
|
||||
|
|
@ -90,7 +90,7 @@ Python's in-built data structures like lists and dictionaries will be used exten
|
|||
|
||||
For testing, we can use the PyTest framework. This is a mature full-featured Python testing tool that helps you write better programs.
|
||||
|
||||
## Python package name:
|
||||
## Project Name:
|
||||
```python
|
||||
"adventure_game"
|
||||
```
|
||||
|
|
@ -100,7 +100,7 @@ For testing, we can use the PyTest framework. This is a mature full-featured Pyt
|
|||
file_list = ["main.py", "room.py", "player.py", "game.py", "object.py", "puzzle.py", "test_game.py"]
|
||||
```
|
||||
|
||||
## Data structures and interface definitions:
|
||||
## Data structures and interfaces:
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Room{
|
||||
|
|
@ -209,7 +209,7 @@ Shared knowledge for this project includes understanding the basic principles of
|
|||
"""
|
||||
```
|
||||
|
||||
## Anything UNCLEAR: Provide as Plain text. Make clear here. For example, don't forget a main entry. don't forget to init 3rd party libs.
|
||||
## Anything UNCLEAR: Provide as Plain text. Try to clarify it. For example, don't forget a main entry. don't forget to init 3rd party libs.
|
||||
```python
|
||||
"""
|
||||
The original requirements did not specify whether the game should have a save/load feature, multiplayer support, or any specific graphical user interface. More information on these aspects could help in further refining the product design and requirements.
|
||||
|
|
@ -311,12 +311,10 @@ TASKS = [
|
|||
"添加数据API:接受用户输入的文档库,对文档库进行索引\n- 使用MeiliSearch连接并添加文档库",
|
||||
"搜索API:接收用户输入的关键词,返回相关的搜索结果\n- 使用MeiliSearch连接并使用接口获得对应数据",
|
||||
"多条件筛选API:接收用户选择的筛选条件,返回符合条件的搜索结果。\n- 使用MeiliSearch进行筛选并返回符合条件的搜索结果",
|
||||
"智能推荐API:根据用户的搜索历史记录和搜索行为,推荐相关的搜索结果。"
|
||||
"智能推荐API:根据用户的搜索历史记录和搜索行为,推荐相关的搜索结果。",
|
||||
]
|
||||
|
||||
TASKS_2 = [
|
||||
"完成main.py的功能"
|
||||
]
|
||||
TASKS_2 = ["完成main.py的功能"]
|
||||
|
||||
SEARCH_CODE_SAMPLE = """
|
||||
import requests
|
||||
|
|
@ -460,7 +458,7 @@ if __name__ == '__main__':
|
|||
print('No results found.')
|
||||
'''
|
||||
|
||||
MEILI_CODE = '''import meilisearch
|
||||
MEILI_CODE = """import meilisearch
|
||||
from typing import List
|
||||
|
||||
|
||||
|
|
@ -496,9 +494,9 @@ if __name__ == '__main__':
|
|||
|
||||
# 添加文档库到搜索引擎
|
||||
search_engine.add_documents(books_data_source, documents)
|
||||
'''
|
||||
"""
|
||||
|
||||
MEILI_ERROR = '''/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
|
||||
MEILI_ERROR = """/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
|
||||
Traceback (most recent call last):
|
||||
File "/Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py", line 44, in <module>
|
||||
search_engine.add_documents(books_data_source, documents)
|
||||
|
|
@ -506,7 +504,7 @@ Traceback (most recent call last):
|
|||
index = self.client.get_or_create_index(index_name)
|
||||
AttributeError: 'Client' object has no attribute 'get_or_create_index'
|
||||
|
||||
Process finished with exit code 1'''
|
||||
Process finished with exit code 1"""
|
||||
|
||||
MEILI_CODE_REFINED = """
|
||||
"""
|
||||
|
|
@ -5,9 +5,37 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_action.py
|
||||
"""
|
||||
from metagpt.actions import Action, WritePRD, WriteTest
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action, ActionType, WritePRD, WriteTest
|
||||
|
||||
|
||||
def test_action_repr():
|
||||
actions = [Action(), WriteTest(), WritePRD()]
|
||||
assert "WriteTest" in str(actions)
|
||||
|
||||
|
||||
def test_action_type():
|
||||
assert ActionType.WRITE_PRD.value == WritePRD
|
||||
assert ActionType.WRITE_TEST.value == WriteTest
|
||||
assert ActionType.WRITE_PRD.name == "WRITE_PRD"
|
||||
assert ActionType.WRITE_TEST.name == "WRITE_TEST"
|
||||
|
||||
|
||||
def test_simple_action():
|
||||
action = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
assert action.name == "AlexSay"
|
||||
assert action.node.instruction == "Express your opinion with emotion and don't repeat it"
|
||||
|
||||
|
||||
def test_empty_action():
|
||||
action = Action()
|
||||
assert action.name == "Action"
|
||||
assert not action.node
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_action_exception():
|
||||
action = Action()
|
||||
with pytest.raises(NotImplementedError):
|
||||
await action.run()
|
||||
|
|
|
|||
171
tests/metagpt/actions/test_action_node.py
Normal file
171
tests/metagpt/actions/test_action_node.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/23 15:49
|
||||
@Author : alexanderwu
|
||||
@File : test_action_node.py
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.team import Team
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debate_two_roles():
|
||||
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(
|
||||
name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
|
||||
)
|
||||
trump = Role(
|
||||
name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
|
||||
)
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden, trump])
|
||||
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debate_one_role_in_env():
|
||||
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden])
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debate_one_role():
|
||||
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
msg: Message = await biden.run("Topic: climate change. Under 80 words per message.")
|
||||
|
||||
assert len(msg.content) > 10
|
||||
assert msg.sent_from == "metagpt.roles.role.Role"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_node_one_layer():
|
||||
node = ActionNode(key="key-a", expected_type=str, instruction="instruction-b", example="example-c")
|
||||
|
||||
raw_template = node.compile(context="123", schema="raw", mode="auto")
|
||||
json_template = node.compile(context="123", schema="json", mode="auto")
|
||||
markdown_template = node.compile(context="123", schema="markdown", mode="auto")
|
||||
node_dict = node.to_dict()
|
||||
|
||||
assert "123" in raw_template
|
||||
assert "instruction" in raw_template
|
||||
|
||||
assert "123" in json_template
|
||||
assert "format example" in json_template
|
||||
assert "constraint" in json_template
|
||||
assert "action" in json_template
|
||||
assert "[/" in json_template
|
||||
|
||||
assert "123" in markdown_template
|
||||
assert "key-a" in markdown_template
|
||||
|
||||
assert node_dict["key-a"] == "instruction-b"
|
||||
assert "key-a" in repr(node)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_node_two_layer():
|
||||
node_a = ActionNode(key="reasoning", expected_type=str, instruction="reasoning step by step", example="")
|
||||
node_b = ActionNode(key="answer", expected_type=str, instruction="the final answer", example="")
|
||||
|
||||
root = ActionNode.from_children(key="detail answer", nodes=[node_a, node_b])
|
||||
assert "reasoning" in root.children
|
||||
assert node_b in root.children.values()
|
||||
|
||||
# FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST.
|
||||
answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
|
||||
assert "579" in answer1.content
|
||||
|
||||
answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
|
||||
assert "579" in answer2.content
|
||||
|
||||
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
|
||||
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."],
|
||||
],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
|
||||
}
|
||||
|
||||
t_dict_min = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
"Required Other language third-party packages": (str, ...),
|
||||
"Full API spec": (str, ...),
|
||||
"Logic Analysis": (List[Tuple[str, str]], ...),
|
||||
"Task list": (List[str], ...),
|
||||
"Shared Knowledge": (str, ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING_MISSING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
}
|
||||
|
||||
|
||||
def test_create_model_class():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
output = test_class(**t_dict)
|
||||
print(output.schema())
|
||||
assert output.schema()["title"] == "test_class"
|
||||
assert output.schema()["type"] == "object"
|
||||
assert output.schema()["properties"]["Full API spec"]
|
||||
|
||||
|
||||
def test_create_model_class_with_fields_unrecognized():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
_ = test_class(**t_dict) # just warning
|
||||
|
||||
|
||||
def test_create_model_class_with_fields_missing():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
_ = test_class(**t_dict_min)
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t1 = t(**t_dict)
|
||||
value = t1.model_dump()["Task list"]
|
||||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
"""
|
||||
@Time : 2023/7/11 10:49
|
||||
@Author : chengmaoyu
|
||||
@File : test_action_output
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
from metagpt.actions import ActionOutput
|
||||
|
||||
t_dict = {"Required Python third-party packages": "\"\"\"\nflask==1.1.2\npygame==2.0.1\n\"\"\"\n",
|
||||
"Required Other language third-party packages": "\"\"\"\nNo third-party packages required for other languages.\n\"\"\"\n",
|
||||
"Full API spec": "\"\"\"\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n '200':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n '200':\n description: A JSON object of the updated game state\n\"\"\"\n",
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."]],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?"}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
"Required Other language third-party packages": (str, ...),
|
||||
"Full API spec": (str, ...),
|
||||
"Logic Analysis": (List[Tuple[str, str]], ...),
|
||||
"Task list": (List[str], ...),
|
||||
"Shared Knowledge": (str, ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
|
||||
|
||||
def test_create_model_class():
|
||||
test_class = ActionOutput.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
t = ActionOutput.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t1 = t(**t_dict)
|
||||
value = t1.dict()["Task list"]
|
||||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/1 22:50
|
||||
@Author : alexanderwu
|
||||
@File : test_azure_tts.py
|
||||
"""
|
||||
from metagpt.actions.azure_tts import AzureTTS
|
||||
|
||||
|
||||
def test_azure_tts():
|
||||
azure_tts = AzureTTS("azure_tts")
|
||||
azure_tts.synthesize_speech(
|
||||
"zh-CN",
|
||||
"zh-CN-YunxiNeural",
|
||||
"Boy",
|
||||
"你好,我是卡卡",
|
||||
"output.wav")
|
||||
|
||||
# 运行需要先配置 SUBSCRIPTION_KEY
|
||||
# TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.clone_function import CloneFunction, run_function_code
|
||||
|
||||
|
||||
source_code = """
|
||||
import pandas as pd
|
||||
import ta
|
||||
|
||||
def user_indicator():
|
||||
# 读取股票数据
|
||||
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
|
||||
stock_data.head()
|
||||
# 计算简单移动平均线
|
||||
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
|
||||
stock_data[['Date', 'Close', 'SMA']].head()
|
||||
# 计算布林带
|
||||
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
|
||||
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
|
||||
"""
|
||||
|
||||
template_code = """
|
||||
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame:
|
||||
import pandas as pd
|
||||
# here is your code.
|
||||
"""
|
||||
|
||||
|
||||
def get_expected_res():
|
||||
import pandas as pd
|
||||
import ta
|
||||
|
||||
# 读取股票数据
|
||||
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
|
||||
stock_data.head()
|
||||
# 计算简单移动平均线
|
||||
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
|
||||
stock_data[['Date', 'Close', 'SMA']].head()
|
||||
# 计算布林带
|
||||
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
|
||||
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
|
||||
return stock_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_function():
|
||||
clone = CloneFunction()
|
||||
code = await clone.run(template_code, source_code)
|
||||
assert 'def ' in code
|
||||
stock_path = './tests/data/baba_stock.csv'
|
||||
df, msg = run_function_code(code, 'stock_indicator', stock_path)
|
||||
assert not msg
|
||||
expected_df = get_expected_res()
|
||||
assert df.equals(expected_df)
|
||||
|
|
@ -4,17 +4,19 @@
|
|||
@Time : 2023/5/11 17:46
|
||||
@Author : alexanderwu
|
||||
@File : test_debug_error.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.debug_error import DebugError
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
EXAMPLE_MSG_CONTENT = '''
|
||||
---
|
||||
## Development Code File Name
|
||||
player.py
|
||||
## Development Code
|
||||
```python
|
||||
CODE_CONTENT = '''
|
||||
from typing import List
|
||||
from deck import Deck
|
||||
from card import Card
|
||||
|
|
@ -58,12 +60,9 @@ class Player:
|
|||
if self.score > 21 and any(card.rank == 'A' for card in self.hand):
|
||||
self.score -= 10
|
||||
return self.score
|
||||
'''
|
||||
|
||||
```
|
||||
## Test File Name
|
||||
test_player.py
|
||||
## Test Code
|
||||
```python
|
||||
TEST_CONTENT = """
|
||||
import unittest
|
||||
from blackjack_game.player import Player
|
||||
from blackjack_game.deck import Deck
|
||||
|
|
@ -114,42 +113,52 @@ class TestPlayer(unittest.TestCase):
|
|||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
```
|
||||
## Running Command
|
||||
python tests/test_player.py
|
||||
## Running Output
|
||||
standard output: ;
|
||||
standard errors: ..F..
|
||||
======================================================================
|
||||
FAIL: test_player_calculate_score_with_multiple_aces (__main__.TestPlayer)
|
||||
----------------------------------------------------------------------
|
||||
Traceback (most recent call last):
|
||||
File "tests/test_player.py", line 46, in test_player_calculate_score_with_multiple_aces
|
||||
self.assertEqual(player.score, 12)
|
||||
AssertionError: 22 != 12
|
||||
"""
|
||||
|
||||
----------------------------------------------------------------------
|
||||
Ran 5 tests in 0.007s
|
||||
|
||||
FAILED (failures=1)
|
||||
;
|
||||
## instruction:
|
||||
The error is in the development code, specifically in the calculate_score method of the Player class. The method is not correctly handling the case where there are multiple Aces in the player's hand. The current implementation only subtracts 10 from the score once if the score is over 21 and there's an Ace in the hand. However, in the case of multiple Aces, it should subtract 10 for each Ace until the score is 21 or less.
|
||||
## File To Rewrite:
|
||||
player.py
|
||||
## Status:
|
||||
FAIL
|
||||
## Send To:
|
||||
Engineer
|
||||
---
|
||||
'''
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_error():
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / uuid.uuid4().hex
|
||||
ctx = RunCodeContext(
|
||||
code_filename="player.py",
|
||||
test_filename="test_player.py",
|
||||
command=["python", "tests/test_player.py"],
|
||||
output_filename="output.log",
|
||||
)
|
||||
|
||||
debug_error = DebugError("debug_error")
|
||||
await FileRepository.save_file(filename=ctx.code_filename, content=CODE_CONTENT, relative_path=CONFIG.src_workspace)
|
||||
await FileRepository.save_file(filename=ctx.test_filename, content=TEST_CONTENT, relative_path=TEST_CODES_FILE_REPO)
|
||||
output_data = RunCodeResult(
|
||||
stdout=";",
|
||||
stderr="",
|
||||
summary="======================================================================\n"
|
||||
"FAIL: test_player_calculate_score_with_multiple_aces (__main__.TestPlayer)\n"
|
||||
"----------------------------------------------------------------------\n"
|
||||
"Traceback (most recent call last):\n"
|
||||
' File "tests/test_player.py", line 46, in test_player_calculate_score_'
|
||||
"with_multiple_aces\n"
|
||||
" self.assertEqual(player.score, 12)\nAssertionError: 22 != 12\n\n"
|
||||
"----------------------------------------------------------------------\n"
|
||||
"Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n",
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename=ctx.output_filename, content=output_data.model_dump_json(), relative_path=TEST_OUTPUTS_FILE_REPO
|
||||
)
|
||||
debug_error = DebugError(context=ctx)
|
||||
|
||||
file_name, rewritten_code = await debug_error.run(context=EXAMPLE_MSG_CONTENT)
|
||||
rsp = await debug_error.run()
|
||||
|
||||
assert "class Player" in rewritten_code # rewrite the same class
|
||||
assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12")
|
||||
assert "class Player" in rsp # rewrite the same class
|
||||
# Problematic code:
|
||||
# ```
|
||||
# if self.score > 21 and any(card.rank == 'A' for card in self.hand):
|
||||
# self.score -= 10
|
||||
# ```
|
||||
# Should rewrite to (used "gpt-3.5-turbo-1106"):
|
||||
# ```
|
||||
# ace_count = sum(1 for card in self.hand if card.rank == 'A')
|
||||
# while self.score > 21 and ace_count > 0:
|
||||
# self.score -= 10
|
||||
# ace_count -= 1
|
||||
# ```
|
||||
assert "while self.score > 21" in rsp
|
||||
|
|
|
|||
|
|
@ -4,33 +4,27 @@
|
|||
@Time : 2023/5/11 19:26
|
||||
@Author : alexanderwu
|
||||
@File : test_design_api.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.const import PRDS_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.actions.mock import PRD_SAMPLE
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock_markdown import PRD_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api():
|
||||
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
|
||||
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", PRD_SAMPLE]
|
||||
for prd in inputs:
|
||||
await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
|
||||
|
||||
design_api = WriteDesign("design_api")
|
||||
design_api = WriteDesign()
|
||||
|
||||
result = await design_api.run([Message(content=prd, instruct_content=None)])
|
||||
logger.info(result)
|
||||
result = await design_api.run(Message(content=prd, instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api_calculator():
|
||||
prd = PRD_SAMPLE
|
||||
|
||||
design_api = WriteDesign("design_api")
|
||||
result = await design_api.run([Message(content=prd, instruct_content=None)])
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
assert result
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ API列表:
|
|||
"""
|
||||
_ = "API设计看起来非常合理,满足了PRD中的所有需求。"
|
||||
|
||||
design_api_review = DesignReview("design_api_review")
|
||||
design_api_review = DesignReview()
|
||||
|
||||
result = await design_api_review.run(prd, api_design)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/13 00:26
|
||||
@Author : fisherdeng
|
||||
@File : test_detail_mining.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.detail_mining import DetailMining
|
||||
from metagpt.logs import logger
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detail_mining():
|
||||
topic = "如何做一个生日蛋糕"
|
||||
record = "我认为应该先准备好材料,然后再开始做蛋糕。"
|
||||
detail_mining = DetailMining("detail_mining")
|
||||
rsp = await detail_mining.run(topic=topic, record=record)
|
||||
logger.info(f"{rsp.content=}")
|
||||
|
||||
assert '##OUTPUT' in rsp.content
|
||||
assert '蛋糕' in rsp.content
|
||||
|
||||
17
tests/metagpt/actions/test_fix_bug.py
Normal file
17
tests/metagpt/actions/test_fix_bug.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/25 22:38
|
||||
@Author : alexanderwu
|
||||
@File : test_fix_bug.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_bug():
|
||||
fix_bug = FixBug()
|
||||
assert fix_bug.name == "FixBug"
|
||||
29
tests/metagpt/actions/test_generate_questions.py
Normal file
29
tests/metagpt/actions/test_generate_questions.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/13 00:26
|
||||
@Author : fisherdeng
|
||||
@File : test_generate_questions.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.generate_questions import GenerateQuestions
|
||||
from metagpt.logs import logger
|
||||
|
||||
context = """
|
||||
## topic
|
||||
如何做一个生日蛋糕
|
||||
|
||||
## record
|
||||
我认为应该先准备好材料,然后再开始做蛋糕。
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_questions():
|
||||
action = GenerateQuestions()
|
||||
rsp = await action.run(context)
|
||||
logger.info(f"{rsp.content=}")
|
||||
|
||||
assert "Questions" in rsp.content
|
||||
assert "1." in rsp.content
|
||||
|
|
@ -7,27 +7,25 @@
|
|||
@File : test_invoice_ocr.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invoice_path",
|
||||
[
|
||||
"../../data/invoices/invoice-3.jpg",
|
||||
"../../data/invoices/invoice-4.zip",
|
||||
]
|
||||
Path("invoices/invoice-3.jpg"),
|
||||
Path("invoices/invoice-4.zip"),
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr(invoice_path: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
async def test_invoice_ocr(invoice_path: Path):
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
resp = await InvoiceOCR().run(file_path=Path(invoice_path))
|
||||
assert isinstance(resp, list)
|
||||
|
||||
|
||||
|
|
@ -35,38 +33,29 @@ async def test_invoice_ocr(invoice_path: str):
|
|||
@pytest.mark.parametrize(
|
||||
("invoice_path", "expected_result"),
|
||||
[
|
||||
(
|
||||
"../../data/invoices/invoice-1.pdf",
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
"城市": "深圳市",
|
||||
"总费用/元": "412.00",
|
||||
"开票日期": "2023年02月03日"
|
||||
}
|
||||
]
|
||||
),
|
||||
]
|
||||
(Path("invoices/invoice-1.pdf"), {"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}),
|
||||
],
|
||||
)
|
||||
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
async def test_generate_table(invoice_path: Path, expected_result: dict):
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
filename = invoice_path.name
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
|
||||
table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
|
||||
assert table_data == expected_result
|
||||
assert isinstance(table_data, list)
|
||||
table_data = table_data[0]
|
||||
assert expected_result["收款人"] == table_data["收款人"]
|
||||
assert expected_result["城市"] in table_data["城市"]
|
||||
assert float(expected_result["总费用/元"]) == float(table_data["总费用/元"])
|
||||
assert expected_result["开票日期"] == table_data["开票日期"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("invoice_path", "query", "expected_result"),
|
||||
[
|
||||
("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")
|
||||
]
|
||||
[(Path("invoices/invoice-1.pdf"), "Invoicing date", "2023年02月03日")],
|
||||
)
|
||||
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
async def test_reply_question(invoice_path: Path, query: dict, expected_result: str):
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
|
||||
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
|
||||
assert expected_result in result
|
||||
|
||||
|
|
|
|||
30
tests/metagpt/actions/test_prepare_documents.py
Normal file
30
tests/metagpt/actions/test_prepare_documents.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/6
|
||||
@Author : mashenquan
|
||||
@File : test_prepare_documents.py
|
||||
@Desc: Unit test for prepare_documents.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_documents():
|
||||
msg = Message(content="New user requirements balabala...")
|
||||
|
||||
if CONFIG.git_repo:
|
||||
CONFIG.git_repo.delete_repository()
|
||||
CONFIG.git_repo = None
|
||||
|
||||
await PrepareDocuments().run(with_messages=[msg])
|
||||
assert CONFIG.git_repo
|
||||
doc = await FileRepository.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
|
||||
assert doc
|
||||
assert doc.content == msg.content
|
||||
21
tests/metagpt/actions/test_prepare_interview.py
Normal file
21
tests/metagpt/actions/test_prepare_interview.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/13 00:26
|
||||
@Author : fisherdeng
|
||||
@File : test_generate_questions.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.prepare_interview import PrepareInterview
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_interview():
|
||||
action = PrepareInterview()
|
||||
rsp = await action.run("I just graduated and hope to find a job as a Python engineer")
|
||||
logger.info(f"{rsp.content=}")
|
||||
|
||||
assert "Questions" in rsp.content
|
||||
assert "1." in rsp.content
|
||||
|
|
@ -6,10 +6,26 @@
|
|||
@File : test_project_management.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
class TestCreateProjectPlan:
|
||||
pass
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock_json import DESIGN, PRD
|
||||
|
||||
|
||||
class TestAssignTasks:
|
||||
pass
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api():
|
||||
await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
|
||||
await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
logger.info(CONFIG.git_repo)
|
||||
|
||||
action = WriteTasks()
|
||||
|
||||
result = await action.run(Message(content="", instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
|
|
|||
57
tests/metagpt/actions/test_rebuild_class_view.py
Normal file
57
tests/metagpt/actions/test_rebuild_class_view.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/20
|
||||
@Author : mashenquan
|
||||
@File : test_rebuild_class_view.py
|
||||
@Desc : Unit tests for rebuild_class_view.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.rebuild_class_view import RebuildClassView
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild():
|
||||
action = RebuildClassView(
|
||||
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
|
||||
)
|
||||
await action.run()
|
||||
graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
|
||||
assert graph_file_repo.changed_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("path", "direction", "diff", "want"),
|
||||
[
|
||||
("metagpt/startup.py", "=", ".", "metagpt/startup.py"),
|
||||
("metagpt/startup.py", "+", "MetaGPT", "MetaGPT/metagpt/startup.py"),
|
||||
("metagpt/startup.py", "-", "metagpt", "startup.py"),
|
||||
],
|
||||
)
|
||||
def test_align_path(path, direction, diff, want):
|
||||
res = RebuildClassView._align_root(path=path, direction=direction, diff_path=diff)
|
||||
assert res == want
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("path_root", "package_root", "want_direction", "want_diff"),
|
||||
[
|
||||
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT/metagpt", "=", "."),
|
||||
("/Users/x/github/MetaGPT", "/Users/x/github/MetaGPT/metagpt", "-", "metagpt"),
|
||||
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT", "+", "metagpt"),
|
||||
],
|
||||
)
|
||||
def test_diff_path(path_root, package_root, want_direction, want_diff):
|
||||
direction, diff = RebuildClassView._diff_path(path_root=Path(path_root), package_root=Path(package_root))
|
||||
assert direction == want_direction
|
||||
assert diff == want_diff
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
55
tests/metagpt/actions/test_rebuild_sequence_view.py
Normal file
55
tests/metagpt/actions/test_rebuild_sequence_view.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4
|
||||
@Author : mashenquan
|
||||
@File : test_rebuild_sequence_view.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import ChangeType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild():
|
||||
# Mock
|
||||
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
|
||||
graph_db_filename = Path(CONFIG.git_repo.workdir.name).with_suffix(".json")
|
||||
await FileRepository.save_file(
|
||||
filename=str(graph_db_filename),
|
||||
relative_path=GRAPH_REPO_FILE_REPO,
|
||||
content=data,
|
||||
)
|
||||
CONFIG.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
|
||||
CONFIG.git_repo.commit("commit1")
|
||||
|
||||
action = RebuildSequenceView(
|
||||
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
|
||||
)
|
||||
await action.run()
|
||||
graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
|
||||
assert graph_file_repo.changed_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("root", "pathname", "want"),
|
||||
[
|
||||
(Path(__file__).parent.parent.parent, "/".join(__file__.split("/")[-2:]), Path(__file__)),
|
||||
(Path(__file__).parent.parent.parent, "f/g.txt", None),
|
||||
],
|
||||
)
|
||||
def test_get_full_filename(root, pathname, want):
|
||||
res = RebuildSequenceView._get_full_filename(root=root, pathname=pathname)
|
||||
assert res == want
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
117
tests/metagpt/actions/test_research.py
Normal file
117
tests/metagpt/actions/test_research.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_research.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import research
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collect_links(mocker):
|
||||
async def mock_llm_ask(self, prompt: str, system_msgs):
|
||||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["metagpt", "llm"]'
|
||||
|
||||
elif "Provide up to 4 queries related to your research topic" in prompt:
|
||||
return (
|
||||
'["MetaGPT use cases", "The roadmap of MetaGPT", '
|
||||
'"The function of MetaGPT", "What llm MetaGPT support"]'
|
||||
)
|
||||
elif "sort the remaining search results" in prompt:
|
||||
return "[1,2]"
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
resp = await research.CollectLinks().run("The application of MetaGPT")
|
||||
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
|
||||
assert i in resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collect_links_with_rank_func(mocker):
|
||||
rank_before = []
|
||||
rank_after = []
|
||||
url_per_query = 4
|
||||
|
||||
def rank_func(results):
|
||||
results = results[:url_per_query]
|
||||
rank_before.append(results)
|
||||
results = results[::-1]
|
||||
rank_after.append(results)
|
||||
return results
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
|
||||
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
|
||||
for x, y, z in zip(rank_before, rank_after, resp.values()):
|
||||
assert x[::-1] == y
|
||||
assert [i["link"] for i in y] == z
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_browse_and_summarize(mocker):
|
||||
async def mock_llm_ask(*args, **kwargs):
|
||||
return "metagpt"
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
url = "https://github.com/geekan/MetaGPT"
|
||||
url2 = "https://github.com/trending"
|
||||
query = "What's new in metagpt"
|
||||
resp = await research.WebBrowseAndSummarize().run(url, query=query)
|
||||
|
||||
assert len(resp) == 1
|
||||
assert url in resp
|
||||
assert resp[url] == "metagpt"
|
||||
|
||||
resp = await research.WebBrowseAndSummarize().run(url, url2, query=query)
|
||||
assert len(resp) == 2
|
||||
|
||||
async def mock_llm_ask(*args, **kwargs):
|
||||
return "Not relevant."
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
resp = await research.WebBrowseAndSummarize().run(url, query=query)
|
||||
|
||||
assert len(resp) == 1
|
||||
assert url in resp
|
||||
assert resp[url] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conduct_research(mocker):
|
||||
data = None
|
||||
|
||||
async def mock_llm_ask(*args, **kwargs):
|
||||
nonlocal data
|
||||
data = f"# Research Report\n## Introduction\n{args} {kwargs}"
|
||||
return data
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
content = (
|
||||
"MetaGPT takes a one line requirement as input and "
|
||||
"outputs user stories / competitive analysis / requirements / data structures / APIs / documents, etc."
|
||||
)
|
||||
|
||||
resp = await research.ConductResearch().run("The application of MetaGPT", content)
|
||||
assert resp == data
|
||||
|
||||
|
||||
async def mock_collect_links_llm_ask(self, prompt: str, system_msgs):
|
||||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["metagpt", "llm"]'
|
||||
|
||||
elif "Provide up to 4 queries related to your research topic" in prompt:
|
||||
return (
|
||||
'["MetaGPT use cases", "The roadmap of MetaGPT", ' '"The function of MetaGPT", "What llm MetaGPT support"]'
|
||||
)
|
||||
elif "sort the remaining search results" in prompt:
|
||||
return "[1,2]"
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -4,21 +4,23 @@
|
|||
@Time : 2023/5/11 17:46
|
||||
@Author : alexanderwu
|
||||
@File : test_run_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.run_code import RunCode
|
||||
from metagpt.schema import RunCodeContext
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_text():
|
||||
result, errs = await RunCode.run_text("result = 1 + 1")
|
||||
assert result == 2
|
||||
assert errs == ""
|
||||
out, err = await RunCode.run_text("result = 1 + 1")
|
||||
assert out == 2
|
||||
assert err == ""
|
||||
|
||||
result, errs = await RunCode.run_text("result = 1 / 0")
|
||||
assert result == ""
|
||||
assert "ZeroDivisionError" in errs
|
||||
out, err = await RunCode.run_text("result = 1 / 0")
|
||||
assert out == ""
|
||||
assert "division by zero" in err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -35,37 +37,29 @@ async def test_run_script():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
action = RunCode()
|
||||
result = await action.run(mode="text", code="print('Hello, World')")
|
||||
assert "PASS" in result
|
||||
|
||||
result = await action.run(
|
||||
mode="script",
|
||||
code="echo 'Hello World'",
|
||||
code_file_name="",
|
||||
test_code="",
|
||||
test_file_name="",
|
||||
command=["echo", "Hello World"],
|
||||
working_directory=".",
|
||||
additional_python_paths=[],
|
||||
)
|
||||
assert "PASS" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_failure():
|
||||
action = RunCode()
|
||||
result = await action.run(mode="text", code="result = 1 / 0")
|
||||
assert "FAIL" in result
|
||||
|
||||
result = await action.run(
|
||||
mode="script",
|
||||
code='python -c "print(1/0)"',
|
||||
code_file_name="",
|
||||
test_code="",
|
||||
test_file_name="",
|
||||
command=["python", "-c", "print(1/0)"],
|
||||
working_directory=".",
|
||||
additional_python_paths=[],
|
||||
)
|
||||
assert "FAIL" in result
|
||||
inputs = [
|
||||
(RunCodeContext(mode="text", code_filename="a.txt", code="print('Hello, World')"), "PASS"),
|
||||
(
|
||||
RunCodeContext(
|
||||
mode="script",
|
||||
code_filename="a.sh",
|
||||
code="echo 'Hello World'",
|
||||
command=["echo", "Hello World"],
|
||||
working_directory=".",
|
||||
),
|
||||
"PASS",
|
||||
),
|
||||
(
|
||||
RunCodeContext(
|
||||
mode="script",
|
||||
code_filename="a.py",
|
||||
code='python -c "print(1/0)"',
|
||||
command=["python", "-c", "print(1/0)"],
|
||||
working_directory=".",
|
||||
),
|
||||
"FAIL",
|
||||
),
|
||||
]
|
||||
for ctx, result in inputs:
|
||||
rsp = await RunCode(context=ctx).run()
|
||||
assert result in rsp.summary
|
||||
|
|
|
|||
91
tests/metagpt/actions/test_skill_action.py
Normal file
91
tests/metagpt/actions/test_skill_action.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/19
|
||||
@Author : mashenquan
|
||||
@File : test_skill_action.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
|
||||
from metagpt.learn.skill_loader import Example, Parameter, Returns, Skill
|
||||
|
||||
|
||||
class TestSkillAction:
|
||||
skill = Skill(
|
||||
name="text_to_image",
|
||||
description="Create a drawing based on the text.",
|
||||
id="text_to_image.text_to_image",
|
||||
x_prerequisite={
|
||||
"configurations": {
|
||||
"OPENAI_API_KEY": {
|
||||
"type": "string",
|
||||
"description": "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`",
|
||||
},
|
||||
"METAGPT_TEXT_TO_IMAGE_MODEL_URL": {"type": "string", "description": "Model url."},
|
||||
},
|
||||
"required": {"oneOf": ["OPENAI_API_KEY", "METAGPT_TEXT_TO_IMAGE_MODEL_URL"]},
|
||||
},
|
||||
parameters={
|
||||
"text": Parameter(type="string", description="The text used for image conversion."),
|
||||
"size_type": Parameter(type="string", description="size type"),
|
||||
},
|
||||
examples=[
|
||||
Example(ask="Draw a girl", answer='text_to_image(text="Draw a girl", size_type="512x512")'),
|
||||
Example(ask="Draw an apple", answer='text_to_image(text="Draw an apple", size_type="512x512")'),
|
||||
],
|
||||
returns=Returns(type="string", format="base64"),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parser(self):
|
||||
args = ArgumentsParingAction.parse_arguments(
|
||||
skill_name="text_to_image", txt='`text_to_image(text="Draw an apple", size_type="512x512")`'
|
||||
)
|
||||
assert args.get("text") == "Draw an apple"
|
||||
assert args.get("size_type") == "512x512"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parser_action(self, mocker):
|
||||
# mock
|
||||
mocker.patch("metagpt.learn.text_to_image", return_value="https://mock.com/xxx")
|
||||
|
||||
parser_action = ArgumentsParingAction(skill=self.skill, ask="Draw an apple")
|
||||
rsp = await parser_action.run()
|
||||
assert rsp
|
||||
assert parser_action.args
|
||||
assert parser_action.args.get("text") == "Draw an apple"
|
||||
assert parser_action.args.get("size_type") == "512x512"
|
||||
|
||||
action = SkillAction(skill=self.skill, args=parser_action.args)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert "image/png;base64," in rsp.content or "http" in rsp.content
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("skill_name", "txt", "want"),
|
||||
[
|
||||
("skill1", 'skill1(a="1", b="2")', {"a": "1", "b": "2"}),
|
||||
("skill1", '(a="1", b="2")', None),
|
||||
("skill1", 'skill1(a="1", b="2"', None),
|
||||
],
|
||||
)
|
||||
def test_parse_arguments(self, skill_name, txt, want):
|
||||
args = ArgumentsParingAction.parse_arguments(skill_name, txt)
|
||||
assert args == want
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_and_call_function_error(self):
|
||||
with pytest.raises(ValueError):
|
||||
await SkillAction.find_and_call_function("dummy_call", {"a": 1})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_action_error(self):
|
||||
action = SkillAction(skill=self.skill, args={})
|
||||
rsp = await action.run()
|
||||
assert "Error" in rsp.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
195
tests/metagpt/actions/test_summarize_code.py
Normal file
195
tests/metagpt/actions/test_summarize_code.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/11 17:46
|
||||
@Author : mashenquan
|
||||
@File : test_summarize_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. Unit test for summarize_code.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
DESIGN_CONTENT = """
|
||||
{"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are already clear."}
|
||||
"""
|
||||
|
||||
TASK_CONTENT = """
|
||||
{"Required Python third-party packages": ["pygame==2.0.1"], "Required Other language third-party packages": ["No third-party packages required for other languages."], "Full API spec": "\n openapi: 3.0.0\n info:\n title: Snake Game API\n version: \"1.0.0\"\n paths:\n /start:\n get:\n summary: Start the game\n responses:\n '200':\n description: Game started successfully\n /pause:\n get:\n summary: Pause the game\n responses:\n '200':\n description: Game paused successfully\n /resume:\n get:\n summary: Resume the game\n responses:\n '200':\n description: Game resumed successfully\n /end:\n get:\n summary: End the game\n responses:\n '200':\n description: Game ended successfully\n /score:\n get:\n summary: Get the current score\n responses:\n '200':\n description: Current score retrieved successfully\n /highscore:\n get:\n summary: Get the high score\n responses:\n '200':\n description: High score retrieved successfully\n components: {}\n ", "Logic Analysis": [["constants.py", "Contains all the constant values like screen size, colors, game speeds, etc. This should be implemented first as it provides the base values for other components."], ["snake.py", "Contains the Snake class with methods for movement, growth, and collision detection. It is dependent on constants.py for configuration values."], ["food.py", "Contains the Food class responsible for spawning food items on the screen. It is dependent on constants.py for configuration values."], ["obstacle.py", "Contains the Obstacle class with methods for spawning, moving, and disappearing of obstacles, as well as collision detection with the snake. It is dependent on constants.py for configuration values."], ["scoreboard.py", "Contains the Scoreboard class for updating, resetting, loading, and saving high scores. It may use constants.py for configuration values and depends on the game's scoring logic."], ["game.py", "Contains the main Game class which includes the game loop and methods for starting, pausing, resuming, and ending the game. It is dependent on snake.py, food.py, obstacle.py, and scoreboard.py."], ["main.py", "The entry point of the game that initializes the game and starts the game loop. It is dependent on game.py."]], "Task list": ["constants.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "game.py", "main.py"], "Shared Knowledge": "\n 'constants.py' should contain all the necessary configurations for the game, such as screen dimensions, color definitions, and speed settings. These constants will be used across multiple files, ensuring consistency and ease of updates. Ensure that the Pygame library is initialized correctly in 'main.py' before starting the game loop. Also, make sure that the game's state is managed properly when pausing and resuming the game.\n ", "Anything UNCLEAR": "The interaction between the 'obstacle.py' and the game loop needs to be clearly defined to ensure obstacles appear and disappear correctly. The lifetime of the obstacle and its random movement should be implemented in a way that does not interfere with the game's performance."}
|
||||
"""
|
||||
|
||||
FOOD_PY = """
|
||||
## food.py
|
||||
import random
|
||||
|
||||
class Food:
|
||||
def __init__(self):
|
||||
self.position = (0, 0)
|
||||
|
||||
def generate(self):
|
||||
x = random.randint(0, 9)
|
||||
y = random.randint(0, 9)
|
||||
self.position = (x, y)
|
||||
|
||||
def get_position(self):
|
||||
return self.position
|
||||
|
||||
"""
|
||||
|
||||
GAME_PY = """
|
||||
## game.py
|
||||
import pygame
|
||||
from snake import Snake
|
||||
from food import Food
|
||||
|
||||
class Game:
|
||||
def __init__(self):
|
||||
self.score = 0
|
||||
self.level = 1
|
||||
self.snake = Snake()
|
||||
self.food = Food()
|
||||
|
||||
def start_game(self):
|
||||
pygame.init()
|
||||
self.initialize_game()
|
||||
self.game_loop()
|
||||
|
||||
def initialize_game(self):
|
||||
self.score = 0
|
||||
self.level = 1
|
||||
self.snake.reset()
|
||||
self.food.generate()
|
||||
|
||||
def game_loop(self):
|
||||
game_over = False
|
||||
|
||||
while not game_over:
|
||||
self.update()
|
||||
self.draw()
|
||||
self.handle_events()
|
||||
self.check_collision()
|
||||
self.increase_score()
|
||||
self.increase_level()
|
||||
|
||||
if self.snake.is_collision():
|
||||
game_over = True
|
||||
self.game_over()
|
||||
|
||||
def update(self):
|
||||
self.snake.move()
|
||||
|
||||
def draw(self):
|
||||
self.snake.draw()
|
||||
self.food.draw()
|
||||
|
||||
def handle_events(self):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
pygame.quit()
|
||||
quit()
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_UP:
|
||||
self.snake.change_direction("UP")
|
||||
elif event.key == pygame.K_DOWN:
|
||||
self.snake.change_direction("DOWN")
|
||||
elif event.key == pygame.K_LEFT:
|
||||
self.snake.change_direction("LEFT")
|
||||
elif event.key == pygame.K_RIGHT:
|
||||
self.snake.change_direction("RIGHT")
|
||||
|
||||
def check_collision(self):
|
||||
if self.snake.get_head() == self.food.get_position():
|
||||
self.snake.grow()
|
||||
self.food.generate()
|
||||
|
||||
def increase_score(self):
|
||||
self.score += 1
|
||||
|
||||
def increase_level(self):
|
||||
if self.score % 10 == 0:
|
||||
self.level += 1
|
||||
|
||||
def game_over(self):
|
||||
print("Game Over")
|
||||
self.initialize_game()
|
||||
|
||||
"""
|
||||
|
||||
MAIN_PY = """
|
||||
## main.py
|
||||
import pygame
|
||||
from game import Game
|
||||
|
||||
def main():
|
||||
pygame.init()
|
||||
game = Game()
|
||||
game.start_game()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
"""
|
||||
|
||||
SNAKE_PY = """
|
||||
## snake.py
|
||||
import pygame
|
||||
|
||||
class Snake:
|
||||
def __init__(self):
|
||||
self.body = [(0, 0)]
|
||||
self.direction = (1, 0)
|
||||
|
||||
def move(self):
|
||||
head = self.body[0]
|
||||
dx, dy = self.direction
|
||||
new_head = (head[0] + dx, head[1] + dy)
|
||||
self.body.insert(0, new_head)
|
||||
self.body.pop()
|
||||
|
||||
def change_direction(self, direction):
|
||||
if direction == "UP":
|
||||
self.direction = (0, -1)
|
||||
elif direction == "DOWN":
|
||||
self.direction = (0, 1)
|
||||
elif direction == "LEFT":
|
||||
self.direction = (-1, 0)
|
||||
elif direction == "RIGHT":
|
||||
self.direction = (1, 0)
|
||||
|
||||
def grow(self):
|
||||
tail = self.body[-1]
|
||||
dx, dy = self.direction
|
||||
new_tail = (tail[0] - dx, tail[1] - dy)
|
||||
self.body.append(new_tail)
|
||||
|
||||
def get_head(self):
|
||||
return self.body[0]
|
||||
|
||||
def get_body(self):
|
||||
return self.body[1:]
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_code():
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / "src"
|
||||
await FileRepository.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
|
||||
await FileRepository.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT)
|
||||
await FileRepository.save_file(filename="food.py", relative_path=CONFIG.src_workspace, content=FOOD_PY)
|
||||
await FileRepository.save_file(filename="game.py", relative_path=CONFIG.src_workspace, content=GAME_PY)
|
||||
await FileRepository.save_file(filename="main.py", relative_path=CONFIG.src_workspace, content=MAIN_PY)
|
||||
await FileRepository.save_file(filename="snake.py", relative_path=CONFIG.src_workspace, content=SNAKE_PY)
|
||||
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
|
||||
all_files = src_file_repo.all_files
|
||||
ctx = CodeSummarizeContext(design_filename="1.json", task_filename="1.json", codes_filenames=all_files)
|
||||
action = SummarizeCode(context=ctx)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
logger.info(rsp)
|
||||
51
tests/metagpt/actions/test_talk_action.py
Normal file
51
tests/metagpt/actions/test_talk_action.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_talk_action.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("agent_description", "language", "context", "knowledge", "history_summary"),
|
||||
[
|
||||
(
|
||||
"mathematician",
|
||||
"English",
|
||||
"How old is Susie?",
|
||||
"Susie is a girl born in 2011/11/14. Today is 2023/12/3",
|
||||
"balabala... (useless words)",
|
||||
),
|
||||
(
|
||||
"mathematician",
|
||||
"Chinese",
|
||||
"Does Susie have an apple?",
|
||||
"Susie is a girl born in 2011/11/14. Today is 2023/12/3",
|
||||
"Susie had an apple, and she ate it right now",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_prompt(agent_description, language, context, knowledge, history_summary):
|
||||
# Prerequisites
|
||||
CONFIG.agent_description = agent_description
|
||||
CONFIG.language = language
|
||||
|
||||
action = TalkAction(context=context, knowledge=knowledge, history_summary=history_summary)
|
||||
assert "{" not in action.prompt
|
||||
assert "{" not in action.prompt_gpt4
|
||||
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert isinstance(rsp, Message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -1,191 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/22 02:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
#
|
||||
from tests.metagpt.roles.ui_role import UIDesign
|
||||
|
||||
llm_resp= '''
|
||||
# UI Design Description
|
||||
```The user interface for the snake game will be designed in a way that is simple, clean, and intuitive. The main elements of the game such as the game grid, snake, food, score, and game over message will be clearly defined and easy to understand. The game grid will be centered on the screen with the score displayed at the top. The game controls will be intuitive and easy to use. The design will be modern and minimalist with a pleasing color scheme.```
|
||||
|
||||
## Selected Elements
|
||||
|
||||
Game Grid: The game grid will be a rectangular area in the center of the screen where the game will take place. It will be defined by a border and will have a darker background color.
|
||||
|
||||
Snake: The snake will be represented by a series of connected blocks that move across the grid. The color of the snake will be different from the background color to make it stand out.
|
||||
|
||||
Food: The food will be represented by small objects that are a different color from the snake and the background. The food will be randomly placed on the grid.
|
||||
|
||||
Score: The score will be displayed at the top of the screen. The score will increase each time the snake eats a piece of food.
|
||||
|
||||
Game Over: When the game is over, a message will be displayed in the center of the screen. The player will be given the option to restart the game.
|
||||
|
||||
## HTML Layout
|
||||
```html
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Snake Game</title>
|
||||
<link rel="stylesheet" href="styles.css">
|
||||
</head>
|
||||
<body>
|
||||
<div class="score">Score: 0</div>
|
||||
<div class="game-grid">
|
||||
<!-- Snake and food will be dynamically generated here using JavaScript -->
|
||||
</div>
|
||||
<div class="game-over">Game Over</div>
|
||||
</body>
|
||||
</html>
|
||||
```
|
||||
|
||||
## CSS Styles (styles.css)
|
||||
```css
|
||||
body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
|
||||
.score {
|
||||
font-size: 2em;
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
|
||||
.game-grid {
|
||||
width: 400px;
|
||||
height: 400px;
|
||||
display: grid;
|
||||
grid-template-columns: repeat(20, 1fr);
|
||||
grid-template-rows: repeat(20, 1fr);
|
||||
gap: 1px;
|
||||
background-color: #222;
|
||||
border: 1px solid #555;
|
||||
}
|
||||
|
||||
.snake-segment {
|
||||
background-color: #00cc66;
|
||||
}
|
||||
|
||||
.food {
|
||||
background-color: #cc3300;
|
||||
}
|
||||
|
||||
.control-panel {
|
||||
display: flex;
|
||||
justify-content: space-around;
|
||||
width: 400px;
|
||||
margin-top: 1em;
|
||||
}
|
||||
|
||||
.control-button {
|
||||
padding: 1em;
|
||||
font-size: 1em;
|
||||
border: none;
|
||||
background-color: #555;
|
||||
color: #fff;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.game-over {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
font-size: 3em;
|
||||
'''
|
||||
|
||||
def test_ui_design_parse_css():
|
||||
ui_design_work = UIDesign(name="UI design action")
|
||||
|
||||
css = '''
|
||||
body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
|
||||
.score {
|
||||
font-size: 2em;
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
|
||||
.game-grid {
|
||||
width: 400px;
|
||||
height: 400px;
|
||||
display: grid;
|
||||
grid-template-columns: repeat(20, 1fr);
|
||||
grid-template-rows: repeat(20, 1fr);
|
||||
gap: 1px;
|
||||
background-color: #222;
|
||||
border: 1px solid #555;
|
||||
}
|
||||
|
||||
.snake-segment {
|
||||
background-color: #00cc66;
|
||||
}
|
||||
|
||||
.food {
|
||||
background-color: #cc3300;
|
||||
}
|
||||
|
||||
.control-panel {
|
||||
display: flex;
|
||||
justify-content: space-around;
|
||||
width: 400px;
|
||||
margin-top: 1em;
|
||||
}
|
||||
|
||||
.control-button {
|
||||
padding: 1em;
|
||||
font-size: 1em;
|
||||
border: none;
|
||||
background-color: #555;
|
||||
color: #fff;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.game-over {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
font-size: 3em;
|
||||
'''
|
||||
assert ui_design_work.parse_css_code(context=llm_resp)==css
|
||||
|
||||
|
||||
def test_ui_design_parse_html():
|
||||
ui_design_work = UIDesign(name="UI design action")
|
||||
|
||||
html = '''
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Snake Game</title>
|
||||
<link rel="stylesheet" href="styles.css">
|
||||
</head>
|
||||
<body>
|
||||
<div class="score">Score: 0</div>
|
||||
<div class="game-grid">
|
||||
<!-- Snake and food will be dynamically generated here using JavaScript -->
|
||||
</div>
|
||||
<div class="game-over">Game Over</div>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
assert ui_design_work.parse_css_code(context=llm_resp)==html
|
||||
|
||||
|
||||
|
||||
|
|
@ -4,31 +4,92 @@
|
|||
@Time : 2023/5/11 17:45
|
||||
@Author : alexanderwu
|
||||
@File : test_write_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
CODE_SUMMARIES_FILE_REPO,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
from metagpt.provider.openai_api import OpenAILLM as LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code():
|
||||
api_design = "设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。"
|
||||
write_code = WriteCode("write_code")
|
||||
context = CodingContext(
|
||||
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
|
||||
)
|
||||
doc = Document(content=context.model_dump_json())
|
||||
write_code = WriteCode(context=doc)
|
||||
|
||||
code = await write_code.run(api_design)
|
||||
logger.info(code)
|
||||
code = await write_code.run()
|
||||
logger.info(code.model_dump_json())
|
||||
|
||||
# 我们不能精确地预测生成的代码,但我们可以检查某些关键字
|
||||
assert 'def add' in code
|
||||
assert 'return' in code
|
||||
assert "def add" in code.code_doc.content
|
||||
assert "return" in code.code_doc.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_directly():
|
||||
prompt = WRITE_CODE_PROMPT_SAMPLE + '\n' + TASKS_2[0]
|
||||
prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0]
|
||||
llm = LLM()
|
||||
rsp = await llm.aask(prompt)
|
||||
logger.info(rsp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_deps():
|
||||
# Prerequisites
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1"
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
await FileRepository.save_file(
|
||||
filename="test_game.py.json",
|
||||
content=await aread(str(demo_path / "test_game.py.json")),
|
||||
relative_path=TEST_OUTPUTS_FILE_REPO,
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="20231221155954.json",
|
||||
content=await aread(str(demo_path / "code_summaries.json")),
|
||||
relative_path=CODE_SUMMARIES_FILE_REPO,
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="20231221155954.json",
|
||||
content=await aread(str(demo_path / "system_design.json")),
|
||||
relative_path=SYSTEM_DESIGN_FILE_REPO,
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONFIG.src_workspace
|
||||
)
|
||||
context = CodingContext(
|
||||
filename="game.py",
|
||||
design_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO),
|
||||
task_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO),
|
||||
code_doc=Document(filename="game.py", content="", root_path="snake1"),
|
||||
)
|
||||
coding_doc = Document(root_path="snake1", filename="game.py", content=context.json())
|
||||
|
||||
action = WriteCode(context=coding_doc)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert rsp.code_doc.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.write_code_review import WriteCodeReview
|
||||
from metagpt.schema import CodingContext, Document
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -16,13 +17,15 @@ async def test_write_code_review(capfd):
|
|||
def add(a, b):
|
||||
return a +
|
||||
"""
|
||||
# write_code_review = WriteCodeReview("write_code_review")
|
||||
context = CodingContext(
|
||||
filename="math.py", design_doc=Document(content="编写一个从a加b的函数,返回a+b"), code_doc=Document(content=code)
|
||||
)
|
||||
|
||||
code = await WriteCodeReview().run(context="编写一个从a加b的函数,返回a+b", code=code, filename="math.py")
|
||||
context = await WriteCodeReview(context=context).run()
|
||||
|
||||
# 我们不能精确地预测生成的代码评审,但我们可以检查返回的是否为字符串
|
||||
assert isinstance(code, str)
|
||||
assert len(code) > 0
|
||||
assert isinstance(context.code_doc.content, str)
|
||||
assert len(context.code_doc.content) > 0
|
||||
|
||||
captured = capfd.readouterr()
|
||||
print(f"输出内容: {captured.out}")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.write_docstring import WriteDocstring
|
||||
|
||||
code = '''
|
||||
code = """
|
||||
def add_numbers(a: int, b: int):
|
||||
return a + b
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ class Person:
|
|||
|
||||
def greet(self):
|
||||
return f"Hello, my name is {self.name} and I am {self.age} years old."
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -25,8 +25,18 @@ class Person:
|
|||
("numpy", "Parameters"),
|
||||
("sphinx", ":param name:"),
|
||||
],
|
||||
ids=["google", "numpy", "sphinx"]
|
||||
ids=["google", "numpy", "sphinx"],
|
||||
)
|
||||
async def test_write_docstring(style: str, part: str):
|
||||
ret = await WriteDocstring().run(code, style=style)
|
||||
assert part in ret
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write():
|
||||
code = await WriteDocstring.write_docstring(__file__)
|
||||
assert code
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -4,23 +4,37 @@
|
|||
@Time : 2023/5/11 17:45
|
||||
@Author : alexanderwu
|
||||
@File : test_write_prd.py
|
||||
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, PRDS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import RoleReactMode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd():
|
||||
async def test_write_prd(new_filename):
|
||||
product_manager = ProductManager()
|
||||
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
|
||||
prd = await product_manager.handle(Message(content=requirements, cause_by=BossRequirement))
|
||||
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)
|
||||
product_manager.rc.react_mode = RoleReactMode.BY_ORDER
|
||||
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
|
||||
assert prd.cause_by == any_to_str(WritePRD)
|
||||
logger.info(requirements)
|
||||
logger.info(prd)
|
||||
|
||||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd != ""
|
||||
assert prd.content != ""
|
||||
assert CONFIG.git_repo.new_file_repository(relative_path=PRDS_FILE_REPO).changed_files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -23,10 +23,14 @@ async def test_write_prd_review():
|
|||
Timeline: The feature should be ready for testing in 1.5 months.
|
||||
"""
|
||||
|
||||
write_prd_review = WritePRDReview("write_prd_review")
|
||||
write_prd_review = WritePRDReview(name="write_prd_review")
|
||||
|
||||
prd_review = await write_prd_review.run(prd)
|
||||
|
||||
# We cannot exactly predict the generated PRD review, but we can check if it is a string and if it is not empty
|
||||
assert isinstance(prd_review, str)
|
||||
assert len(prd_review) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
53
tests/metagpt/actions/test_write_review.py
Normal file
53
tests/metagpt/actions/test_write_review.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/20 15:01
|
||||
@Author : alexanderwu
|
||||
@File : test_write_review.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_review import WriteReview
|
||||
|
||||
CONTEXT = """
|
||||
{
|
||||
"Language": "zh_cn",
|
||||
"Programming Language": "Python",
|
||||
"Original Requirements": "写一个简单的2048",
|
||||
"Project Name": "game_2048",
|
||||
"Product Goals": [
|
||||
"创建一个引人入胜的用户体验",
|
||||
"确保高性能",
|
||||
"提供可定制的功能"
|
||||
],
|
||||
"User Stories": [
|
||||
"作为用户,我希望能够选择不同的难度级别",
|
||||
"作为玩家,我希望在每局游戏结束后能看到我的得分"
|
||||
],
|
||||
"Competitive Analysis": [
|
||||
"Python Snake Game: 界面简单,缺乏高级功能"
|
||||
],
|
||||
"Competitive Quadrant Chart": "quadrantChart\n title \"Reach and engagement of campaigns\"\n x-axis \"Low Reach\" --> \"High Reach\"\n y-axis \"Low Engagement\" --> \"High Engagement\"\n quadrant-1 \"我们应该扩展\"\n quadrant-2 \"需要推广\"\n quadrant-3 \"重新评估\"\n quadrant-4 \"可能需要改进\"\n \"Campaign A\": [0.3, 0.6]\n \"Campaign B\": [0.45, 0.23]\n \"Campaign C\": [0.57, 0.69]\n \"Campaign D\": [0.78, 0.34]\n \"Campaign E\": [0.40, 0.34]\n \"Campaign F\": [0.35, 0.78]\n \"Our Target Product\": [0.5, 0.6]",
|
||||
"Requirement Analysis": "产品应该用户友好。",
|
||||
"Requirement Pool": [
|
||||
[
|
||||
"P0",
|
||||
"主要代码..."
|
||||
],
|
||||
[
|
||||
"P0",
|
||||
"游戏算法..."
|
||||
]
|
||||
],
|
||||
"UI Design draft": "基本功能描述,简单的风格和布局。",
|
||||
"Anything UNCLEAR": "..."
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_review():
|
||||
write_review = WriteReview()
|
||||
review = await write_review.run(CONTEXT)
|
||||
assert review.instruct_content
|
||||
assert review.get("LGTM") in ["LGTM", "LBTM"]
|
||||
26
tests/metagpt/actions/test_write_teaching_plan.py
Normal file
26
tests/metagpt/actions/test_write_teaching_plan.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/28 17:25
|
||||
@Author : mashenquan
|
||||
@File : test_write_teaching_plan.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_teaching_plan import WriteTeachingPlanPart
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("topic", "context"),
|
||||
[("Title", "Lesson 1: Learn to draw an apple."), ("Teaching Content", "Lesson 1: Learn to draw an apple.")],
|
||||
)
|
||||
async def test_write_teaching_plan_part(topic, context):
|
||||
action = WriteTeachingPlanPart(topic=topic, context=context)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.write_test import WriteTest
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Document, TestingContext
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -24,22 +25,17 @@ async def test_write_test():
|
|||
def generate(self, max_y: int, max_x: int):
|
||||
self.position = (random.randint(1, max_y - 1), random.randint(1, max_x - 1))
|
||||
"""
|
||||
context = TestingContext(filename="food.py", code_doc=Document(filename="food.py", content=code))
|
||||
write_test = WriteTest(context=context)
|
||||
|
||||
write_test = WriteTest()
|
||||
|
||||
test_code = await write_test.run(
|
||||
code_to_test=code,
|
||||
test_file_name="test_food.py",
|
||||
source_file_path="/some/dummy/path/cli_snake_game/cli_snake_game/food.py",
|
||||
workspace="/some/dummy/path/cli_snake_game",
|
||||
)
|
||||
logger.info(test_code)
|
||||
context = await write_test.run()
|
||||
logger.info(context.model_dump_json())
|
||||
|
||||
# We cannot exactly predict the generated test cases, but we can check if it is a string and if it is not empty
|
||||
assert isinstance(test_code, str)
|
||||
assert "from cli_snake_game.food import Food" in test_code
|
||||
assert "class TestFood(unittest.TestCase)" in test_code
|
||||
assert "def test_generate" in test_code
|
||||
assert isinstance(context.test_doc.content, str)
|
||||
assert "from food import Food" in context.test_doc.content
|
||||
assert "class TestFood(unittest.TestCase)" in context.test_doc.content
|
||||
assert "def test_generate" in context.test_doc.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -55,3 +51,7 @@ async def test_write_code_invalid_code(mocker):
|
|||
|
||||
# Assert that the returned code is the same as the invalid code string
|
||||
assert code == "Invalid Code String"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -9,14 +9,11 @@ from typing import Dict
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_tutorial import WriteDirectory, WriteContent
|
||||
from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("language", "topic"),
|
||||
[("English", "Write a tutorial about Python")]
|
||||
)
|
||||
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
|
||||
async def test_write_directory(language: str, topic: str):
|
||||
ret = await WriteDirectory(language=language).run(topic=topic)
|
||||
assert isinstance(ret, dict)
|
||||
|
|
@ -30,7 +27,7 @@ async def test_write_directory(language: str, topic: str):
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("language", "topic", "directory"),
|
||||
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})]
|
||||
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
|
||||
)
|
||||
async def test_write_content(language: str, topic: str, directory: Dict):
|
||||
ret = await WriteContent(language=language, directory=directory).run(topic=topic)
|
||||
|
|
|
|||
|
|
@ -12,12 +12,12 @@ from metagpt.document_store.chromadb_store import ChromaStore
|
|||
def test_chroma_store():
|
||||
"""FIXME:chroma使用感觉很诡异,一用Python就挂,测试用例里也是"""
|
||||
# 创建 ChromaStore 实例,使用 'sample_collection' 集合
|
||||
document_store = ChromaStore('sample_collection_1')
|
||||
document_store = ChromaStore("sample_collection_1")
|
||||
|
||||
# 使用 write 方法添加多个文档
|
||||
document_store.write(["This is document1", "This is document2"],
|
||||
[{"source": "google-docs"}, {"source": "notion"}],
|
||||
["doc1", "doc2"])
|
||||
document_store.write(
|
||||
["This is document1", "This is document2"], [{"source": "google-docs"}, {"source": "notion"}], ["doc1", "doc2"]
|
||||
)
|
||||
|
||||
# 使用 add 方法添加一个文档
|
||||
document_store.add("This is document3", {"source": "notion"}, "doc3")
|
||||
|
|
|
|||
|
|
@ -7,22 +7,22 @@
|
|||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.document_store.document import Document
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.document import IndexableDocument
|
||||
|
||||
CASES = [
|
||||
("st/faq.xlsx", "Question", "Answer", 1),
|
||||
("cases/faq.csv", "Question", "Answer", 1),
|
||||
("requirements.txt", None, None, 0),
|
||||
# ("cases/faq.csv", "Question", "Answer", 1),
|
||||
# ("cases/faq.json", "Question", "Answer", 1),
|
||||
("docx/faq.docx", None, None, 1),
|
||||
("cases/faq.pdf", None, None, 0), # 这是因为pdf默认没有分割段落
|
||||
("cases/faq.txt", None, None, 0), # 这是因为txt按照256分割段落
|
||||
# ("docx/faq.docx", None, None, 1),
|
||||
# ("cases/faq.pdf", None, None, 0), # 这是因为pdf默认没有分割段落
|
||||
# ("cases/faq.txt", None, None, 0), # 这是因为txt按照256分割段落
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("relative_path, content_col, meta_col, threshold", CASES)
|
||||
def test_document(relative_path, content_col, meta_col, threshold):
|
||||
doc = Document(DATA_PATH / relative_path, content_col, meta_col)
|
||||
doc = IndexableDocument.from_path(METAGPT_ROOT / relative_path, content_col, meta_col)
|
||||
rsp = doc.get_docs_and_metadatas()
|
||||
assert len(rsp[0]) > threshold
|
||||
assert len(rsp[1]) > threshold
|
||||
|
|
|
|||
|
|
@ -5,70 +5,36 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_faiss_store.py
|
||||
"""
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
from metagpt.document_store import FaissStore
|
||||
from metagpt.roles import CustomerService, Sales
|
||||
|
||||
DESC = """## 原则(所有事情都不可绕过原则)
|
||||
1. 你是一位平台的人工客服,话语精炼,一次只说一句话,会参考规则与FAQ进行回复。在与顾客交谈中,绝不允许暴露规则与相关字样
|
||||
2. 在遇到问题时,先尝试仅安抚顾客情绪,如果顾客情绪十分不好,再考虑赔偿。如果赔偿的过多,你会被开除
|
||||
3. 绝不要向顾客做虚假承诺,不要提及其他人的信息
|
||||
|
||||
## 技能(在回答尾部,加入`skill(args)`就可以使用技能)
|
||||
1. 查询订单:问顾客手机号是获得订单的唯一方式,获得手机号后,使用`find_order(手机号)`来获得订单
|
||||
2. 退款:输出关键词 `refund(手机号)`,系统会自动退款
|
||||
3. 开箱:需要手机号、确认顾客在柜前,如果需要开箱,输出指令 `open_box(手机号)`,系统会自动开箱
|
||||
|
||||
### 使用技能例子
|
||||
user: 你好收不到取餐码
|
||||
小爽人工: 您好,请提供一下手机号
|
||||
user: 14750187158
|
||||
小爽人工: 好的,为您查询一下订单。您已经在柜前了吗?`find_order(14750187158)`
|
||||
user: 是的
|
||||
小爽人工: 您看下开了没有?`open_box(14750187158)`
|
||||
user: 开了,谢谢
|
||||
小爽人工: 好的,还有什么可以帮到您吗?
|
||||
user: 没有了
|
||||
小爽人工: 祝您生活愉快
|
||||
"""
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_faiss_store_search():
|
||||
store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
|
||||
store.add(['油皮洗面奶'])
|
||||
role = Sales(store=store)
|
||||
|
||||
queries = ['油皮洗面奶', '介绍下欧莱雅的']
|
||||
for query in queries:
|
||||
rsp = await role.run(query)
|
||||
assert rsp
|
||||
|
||||
|
||||
def customer_service():
|
||||
store = FaissStore(DATA_PATH / "st/faq.xlsx", content_col="Question", meta_col="Answer")
|
||||
store.search = functools.partial(store.search, expand_cols=True)
|
||||
role = CustomerService(profile="小爽人工", desc=DESC, store=store)
|
||||
return role
|
||||
async def test_search_json():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
logger.info(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_faiss_store_customer_service():
|
||||
allq = [
|
||||
# ["我的餐怎么两小时都没到", "退货吧"],
|
||||
["你好收不到取餐码,麻烦帮我开箱", "14750187158", ]
|
||||
]
|
||||
role = customer_service()
|
||||
for queries in allq:
|
||||
for query in queries:
|
||||
rsp = await role.run(query)
|
||||
assert rsp
|
||||
async def test_search_xlsx():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
logger.info(result)
|
||||
|
||||
|
||||
def test_faiss_store_no_file():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
FaissStore(DATA_PATH / 'wtf.json')
|
||||
@pytest.mark.asyncio
|
||||
async def test_write():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.docstore
|
||||
assert _faiss_store.index
|
||||
|
|
|
|||
|
|
@ -5,27 +5,30 @@
|
|||
@Author : unkn-wn (Leon Yee)
|
||||
@File : test_lancedb_store.py
|
||||
"""
|
||||
from metagpt.document_store.lancedb_store import LanceStore
|
||||
import pytest
|
||||
import random
|
||||
|
||||
@pytest
|
||||
from metagpt.document_store.lancedb_store import LanceStore
|
||||
|
||||
|
||||
def test_lance_store():
|
||||
|
||||
# This simply establishes the connection to the database, so we can drop the table if it exists
|
||||
store = LanceStore('test')
|
||||
store = LanceStore("test")
|
||||
|
||||
store.drop('test')
|
||||
store.drop("test")
|
||||
|
||||
store.write(data=[[random.random() for _ in range(100)] for _ in range(2)],
|
||||
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
|
||||
ids=["doc1", "doc2"])
|
||||
store.write(
|
||||
data=[[random.random() for _ in range(100)] for _ in range(2)],
|
||||
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
|
||||
ids=["doc1", "doc2"],
|
||||
)
|
||||
|
||||
store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3")
|
||||
|
||||
result = store.search([random.random() for _ in range(100)], n_results=3)
|
||||
assert(len(result) == 3)
|
||||
assert len(result) == 3
|
||||
|
||||
store.delete("doc2")
|
||||
result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine')
|
||||
assert(len(result) == 1)
|
||||
result = store.search(
|
||||
[random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric="cosine"
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
|
|
|||
|
|
@ -1,36 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/6/11 21:08
|
||||
@Author : alexanderwu
|
||||
@File : test_milvus_store.py
|
||||
"""
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
|
||||
from metagpt.logs import logger
|
||||
|
||||
book_columns = {'idx': int, 'name': str, 'desc': str, 'emb': np.ndarray, 'price': float}
|
||||
book_data = [
|
||||
[i for i in range(10)],
|
||||
[f"book-{i}" for i in range(10)],
|
||||
[f"book-desc-{i}" for i in range(10000, 10010)],
|
||||
[[random.random() for _ in range(2)] for _ in range(10)],
|
||||
[random.random() for _ in range(10)],
|
||||
]
|
||||
|
||||
|
||||
def test_milvus_store():
|
||||
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
|
||||
milvus_store = MilvusStore(milvus_connection)
|
||||
milvus_store.drop('Book')
|
||||
milvus_store.create_collection('Book', book_columns)
|
||||
milvus_store.add(book_data)
|
||||
milvus_store.build_index('emb')
|
||||
milvus_store.load_collection()
|
||||
|
||||
results = milvus_store.search([[1.0, 1.0]], field='emb')
|
||||
logger.info(results)
|
||||
assert results
|
||||
|
|
@ -24,14 +24,22 @@ random.seed(seed_value)
|
|||
vectors = [[random.random() for _ in range(2)] for _ in range(10)]
|
||||
|
||||
points = [
|
||||
PointStruct(
|
||||
id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
|
||||
)
|
||||
PointStruct(id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10})
|
||||
for idx, vector in enumerate(vectors)
|
||||
]
|
||||
|
||||
|
||||
def test_milvus_store():
|
||||
def assert_almost_equal(actual, expected):
|
||||
delta = 1e-10
|
||||
if isinstance(expected, list):
|
||||
assert len(actual) == len(expected)
|
||||
for ac, exp in zip(actual, expected):
|
||||
assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}"
|
||||
else:
|
||||
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
|
||||
|
||||
|
||||
def test_qdrant_store():
|
||||
qdrant_connection = QdrantConnection(memory=True)
|
||||
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
|
||||
qdrant_store = QdrantStore(qdrant_connection)
|
||||
|
|
@ -44,34 +52,30 @@ def test_milvus_store():
|
|||
qdrant_store.add("Book", points)
|
||||
results = qdrant_store.search("Book", query=[1.0, 1.0])
|
||||
assert results[0]["id"] == 2
|
||||
assert results[0]["score"] == 0.999106722578389
|
||||
assert results[1]["score"] == 7
|
||||
assert results[1]["score"] == 0.9961650411397226
|
||||
assert_almost_equal(results[0]["score"], 0.999106722578389)
|
||||
assert results[1]["id"] == 7
|
||||
assert_almost_equal(results[1]["score"], 0.9961650411397226)
|
||||
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
|
||||
assert results[0]["id"] == 2
|
||||
assert results[0]["score"] == 0.999106722578389
|
||||
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
|
||||
assert results[1]["score"] == 7
|
||||
assert results[1]["score"] == 0.9961650411397226
|
||||
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
|
||||
assert_almost_equal(results[0]["score"], 0.999106722578389)
|
||||
assert_almost_equal(results[0]["vector"], [0.7363563179969788, 0.6765939593315125])
|
||||
assert results[1]["id"] == 7
|
||||
assert_almost_equal(results[1]["score"], 0.9961650411397226)
|
||||
assert_almost_equal(results[1]["vector"], [0.7662628889083862, 0.6425272226333618])
|
||||
results = qdrant_store.search(
|
||||
"Book",
|
||||
query=[1.0, 1.0],
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
|
||||
),
|
||||
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
|
||||
)
|
||||
assert results[0]["id"] == 8
|
||||
assert results[0]["score"] == 0.9100373450784073
|
||||
assert_almost_equal(results[0]["score"], 0.9100373450784073)
|
||||
assert results[1]["id"] == 9
|
||||
assert results[1]["score"] == 0.7127610621127889
|
||||
assert_almost_equal(results[1]["score"], 0.7127610621127889)
|
||||
results = qdrant_store.search(
|
||||
"Book",
|
||||
query=[1.0, 1.0],
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
|
||||
),
|
||||
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
|
||||
return_vector=True,
|
||||
)
|
||||
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
|
||||
assert results[1]["vector"] == [0.9999677538871765, 0.00802854634821415]
|
||||
assert_almost_equal(results[0]["vector"], [0.35037919878959656, 0.9366079568862915])
|
||||
assert_almost_equal(results[1]["vector"], [0.9999677538871765, 0.00802854634821415])
|
||||
|
|
|
|||
0
tests/metagpt/learn/__init__.py
Normal file
0
tests/metagpt/learn/__init__.py
Normal file
27
tests/metagpt/learn/test_google_search.py
Normal file
27
tests/metagpt/learn/test_google_search.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.learn.google_search import google_search
|
||||
|
||||
|
||||
async def mock_google_search():
|
||||
class Input(BaseModel):
|
||||
input: str
|
||||
|
||||
inputs = [{"input": "ai agent"}]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
result = await google_search(seed.input)
|
||||
assert result != ""
|
||||
|
||||
|
||||
def test_suite():
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(mock_google_search())
|
||||
loop.run_until_complete(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_suite()
|
||||
46
tests/metagpt/learn/test_skill_loader.py
Normal file
46
tests/metagpt/learn/test_skill_loader.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/19
|
||||
@Author : mashenquan
|
||||
@File : test_skill_loader.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.skill_loader import SkillsDeclaration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suite():
|
||||
CONFIG.agent_skills = [
|
||||
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
|
||||
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
|
||||
]
|
||||
pathname = Path(__file__).parent / "../../../docs/.well-known/skills.yaml"
|
||||
loader = await SkillsDeclaration.load(skill_yaml_file_name=pathname)
|
||||
skills = loader.get_skill_list()
|
||||
assert skills
|
||||
assert len(skills) >= 3
|
||||
for desc, name in skills.items():
|
||||
assert desc
|
||||
assert name
|
||||
|
||||
entity = loader.entities.get("Assistant")
|
||||
assert entity
|
||||
assert entity.skills
|
||||
for sk in entity.skills:
|
||||
assert sk
|
||||
assert sk.arguments
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
26
tests/metagpt/learn/test_text_to_embedding.py
Normal file
26
tests/metagpt/learn/test_text_to_embedding.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : test_text_to_embedding.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.text_to_embedding import text_to_embedding
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_embedding():
|
||||
# Prerequisites
|
||||
assert CONFIG.OPENAI_API_KEY
|
||||
|
||||
v = await text_to_embedding(text="Panda emoji")
|
||||
assert len(v.data) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
47
tests/metagpt/learn/test_text_to_image.py
Normal file
47
tests/metagpt/learn/test_text_to_image.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : test_text_to_image.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.text_to_image import text_to_image
|
||||
from metagpt.tools.metagpt_text_to_image import MetaGPTText2Image
|
||||
from metagpt.tools.openai_text_to_image import OpenAIText2Image
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_image(mocker):
|
||||
# mock
|
||||
mocker.patch.object(MetaGPTText2Image, "text_2_image", return_value=b"mock MetaGPTText2Image")
|
||||
mocker.patch.object(OpenAIText2Image, "text_2_image", return_value=b"mock OpenAIText2Image")
|
||||
mocker.patch.object(S3, "cache", return_value="http://mock/s3")
|
||||
|
||||
# Prerequisites
|
||||
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
assert CONFIG.OPENAI_API_KEY
|
||||
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
43
tests/metagpt/learn/test_text_to_speech.py
Normal file
43
tests/metagpt/learn/test_text_to_speech.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : test_text_to_speech.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.text_to_speech import text_to_speech
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_speech():
|
||||
# Prerequisites
|
||||
assert CONFIG.IFLYTEK_APP_ID
|
||||
assert CONFIG.IFLYTEK_API_KEY
|
||||
assert CONFIG.IFLYTEK_API_SECRET
|
||||
assert CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
|
||||
assert CONFIG.AZURE_TTS_REGION
|
||||
|
||||
# test azure
|
||||
data = await text_to_speech("panda emoji")
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
# test iflytek
|
||||
## Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = ""
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
data = await text_to_speech("panda emoji")
|
||||
assert "base64" in data or "http" in data
|
||||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -14,9 +14,9 @@ def test_skill_manager():
|
|||
manager = SkillManager()
|
||||
logger.info(manager._store)
|
||||
|
||||
write_prd = WritePRD("WritePRD")
|
||||
write_prd = WritePRD(name="WritePRD")
|
||||
write_prd.desc = "基于老板或其他人的需求进行PRD的撰写,包括用户故事、需求分解等"
|
||||
write_test = WriteTest("WriteTest")
|
||||
write_test = WriteTest(name="WriteTest")
|
||||
write_test.desc = "进行测试用例的撰写"
|
||||
manager.add_skill(write_prd)
|
||||
manager.add_skill(write_test)
|
||||
|
|
@ -24,13 +24,13 @@ def test_skill_manager():
|
|||
skill = manager.get_skill("WriteTest")
|
||||
logger.info(skill)
|
||||
|
||||
rsp = manager.retrieve_skill("写PRD")
|
||||
rsp = manager.retrieve_skill("WritePRD")
|
||||
logger.info(rsp)
|
||||
assert rsp[0] == "WritePRD"
|
||||
|
||||
rsp = manager.retrieve_skill("写测试用例")
|
||||
logger.info(rsp)
|
||||
assert rsp[0] == 'WriteTest'
|
||||
assert rsp[0] == "WriteTest"
|
||||
|
||||
rsp = manager.retrieve_skill_scored("写PRD")
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
71
tests/metagpt/memory/test_brain_memory.py
Normal file
71
tests/metagpt/memory/test_brain_memory.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/27
|
||||
@Author : mashenquan
|
||||
@File : test_brain_memory.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory():
|
||||
memory = BrainMemory()
|
||||
memory.add_talk(Message(content="talk"))
|
||||
assert memory.history[0].role == "user"
|
||||
memory.add_answer(Message(content="answer"))
|
||||
assert memory.history[1].role == "assistant"
|
||||
redis_key = BrainMemory.to_redis_key("none", "user_id", "chat_id")
|
||||
await memory.dumps(redis_key=redis_key)
|
||||
assert memory.exists("talk")
|
||||
assert 1 == memory.to_int("1", 0)
|
||||
memory.last_talk = "AAA"
|
||||
assert memory.pop_last_talk() == "AAA"
|
||||
assert memory.last_talk is None
|
||||
assert memory.is_history_available
|
||||
assert memory.history_text
|
||||
|
||||
memory = await BrainMemory.loads(redis_key=redis_key)
|
||||
assert memory
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input", "tag", "val"),
|
||||
[("[TALK]:Hello", "TALK", "Hello"), ("Hello", None, "Hello"), ("[TALK]Hello", None, "[TALK]Hello")],
|
||||
)
|
||||
def test_extract_info(input, tag, val):
|
||||
t, v = BrainMemory.extract_info(input)
|
||||
assert tag == t
|
||||
assert val == v
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("llm", [LLM(provider=LLMProviderEnum.OPENAI), LLM(provider=LLMProviderEnum.METAGPT)])
|
||||
async def test_memory_llm(llm):
|
||||
memory = BrainMemory()
|
||||
for i in range(500):
|
||||
memory.add_talk(Message(content="Lily is a girl.\n"))
|
||||
|
||||
res = await memory.is_related("apple", "moon", llm)
|
||||
assert not res
|
||||
|
||||
res = await memory.rewrite(sentence="apple Lily eating", context="", llm=llm)
|
||||
assert "Lily" in res
|
||||
|
||||
res = await memory.summarize(llm=llm)
|
||||
assert res
|
||||
|
||||
res = await memory.get_title(llm=llm)
|
||||
assert res
|
||||
assert "Lily" in res
|
||||
assert memory.history or memory.historical_summary
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -1,38 +1,50 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
"""
|
||||
@Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.memory import LongTermMemory
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_ltm_search():
|
||||
assert hasattr(CONFIG, "long_term_memory") is True
|
||||
openai_api_key = CONFIG.openai_api_key
|
||||
assert len(openai_api_key) > 20
|
||||
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
|
||||
assert len(CONFIG.openai_api_key) > 20
|
||||
|
||||
role_id = 'UTUserLtm(Product Manager)'
|
||||
rc = RoleContext(watch=[BossRequirement])
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
from metagpt.environment import Environment
|
||||
|
||||
Environment
|
||||
RoleContext.model_rebuild()
|
||||
rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"})
|
||||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
idea = 'Write a cli snake game'
|
||||
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
|
||||
idea = "Write a cli snake game"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = 'Write a game of cli snake'
|
||||
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
|
||||
sim_idea = "Write a game of cli snake"
|
||||
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = 'Write a 2048 web game'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
ltm.add(new_message)
|
||||
|
|
@ -47,9 +59,13 @@ def test_ltm_search():
|
|||
news = ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = 'Write a Battle City'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_idea = "Write a Battle City"
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
||||
ltm_new.clear()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
57
tests/metagpt/memory/test_memory.py
Normal file
57
tests/metagpt/memory/test_memory.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Memory
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_memory():
|
||||
memory = Memory()
|
||||
|
||||
message1 = Message(content="test message1", role="user1")
|
||||
message2 = Message(content="test message2", role="user2")
|
||||
message3 = Message(content="test message3", role="user1")
|
||||
memory.add(message1)
|
||||
assert memory.count() == 1
|
||||
|
||||
memory.delete_newest()
|
||||
assert memory.count() == 0
|
||||
|
||||
memory.add_batch([message1, message2])
|
||||
assert memory.count() == 2
|
||||
assert len(memory.index.get(message1.cause_by)) == 2
|
||||
|
||||
messages = memory.get_by_role("user1")
|
||||
assert messages[0].content == message1.content
|
||||
|
||||
messages = memory.get_by_content("test message")
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get_by_action(UserRequirement)
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get_by_actions([UserRequirement])
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.try_remember("test message")
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get(k=1)
|
||||
assert len(messages) == 1
|
||||
|
||||
messages = memory.get(k=5)
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.find_news([message3])
|
||||
assert len(messages) == 1
|
||||
|
||||
memory.delete(message1)
|
||||
assert memory.count() == 1
|
||||
messages = memory.get_by_role("user2")
|
||||
assert messages[0].content == message2.content
|
||||
|
||||
memory.clear()
|
||||
assert memory.count() == 0
|
||||
assert len(memory.index) == 0
|
||||
|
|
@ -1,20 +1,30 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
"""
|
||||
@Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
idea = 'Write a cli snake game'
|
||||
role_id = 'UTUser1(Product Manager)'
|
||||
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
|
||||
idea = "Write a cli snake game"
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
|
|
@ -23,14 +33,14 @@ def test_idea_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = 'Write a game of cli snake'
|
||||
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = 'Write a 2048 web game'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
|
|
@ -38,22 +48,17 @@ def test_idea_message():
|
|||
|
||||
|
||||
def test_actionout_message():
|
||||
out_mapping = {
|
||||
'field1': (str, ...),
|
||||
'field2': (List[str], ...)
|
||||
}
|
||||
out_data = {
|
||||
'field1': 'field1 value',
|
||||
'field2': ['field2 value1', 'field2 value2']
|
||||
}
|
||||
ic_obj = ActionOutput.create_model_class('prd', out_mapping)
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
role_id = 'UTUser2(Architect)'
|
||||
content = 'The boss has requested the creation of a command-line interface (CLI) snake game'
|
||||
message = Message(content=content,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD) # WritePRD as test action
|
||||
role_id = "UTUser2(Architect)"
|
||||
content = "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
message = Message(
|
||||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
|
|
@ -62,20 +67,14 @@ def test_actionout_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = 'The request is command-line interface (CLI) snake game'
|
||||
sim_message = Message(content=sim_conent,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
sim_conent = "The request is command-line interface (CLI) snake game"
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = 'Incorporate basic features of a snake game such as scoring and increasing difficulty'
|
||||
new_message = Message(content=new_conent,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
@Time : 2023/9/16 20:03
|
||||
@Author : femto Zheng
|
||||
@File : test_basic_planner.py
|
||||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import pytest
|
||||
from semantic_kernel.core_skills import FileIOSkill, MathSkill, TextSkill, TimeSkill
|
||||
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
|
||||
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.roles.sk_agent import SkAgent
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -23,7 +25,8 @@ async def test_action_planner():
|
|||
role.import_skill(TimeSkill(), "time")
|
||||
role.import_skill(TextSkill(), "text")
|
||||
task = "What is the sum of 110 and 990?"
|
||||
role.recv(Message(content=task, cause_by=BossRequirement))
|
||||
|
||||
role.put_message(Message(content=task, cause_by=UserRequirement))
|
||||
await role._observe()
|
||||
await role._think() # it will choose mathskill.Add
|
||||
assert "1100" == (await role._act()).content
|
||||
|
|
|
|||
|
|
@ -4,11 +4,13 @@
|
|||
@Time : 2023/9/16 20:03
|
||||
@Author : femto Zheng
|
||||
@File : test_basic_planner.py
|
||||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import pytest
|
||||
from semantic_kernel.core_skills import TextSkill
|
||||
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.const import SKILL_DIRECTORY
|
||||
from metagpt.roles.sk_agent import SkAgent
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -26,7 +28,8 @@ async def test_basic_planner():
|
|||
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "WriterSkill")
|
||||
role.import_skill(TextSkill(), "TextSkill")
|
||||
# using BasicPlanner
|
||||
role.recv(Message(content=task, cause_by=BossRequirement))
|
||||
role.put_message(Message(content=task, cause_by=UserRequirement))
|
||||
await role._observe()
|
||||
await role._think()
|
||||
# assuming sk_agent will think he needs WriterSkill.Brainstorm and WriterSkill.Translate
|
||||
assert "WriterSkill.Brainstorm" in role.plan.generated_plan.result
|
||||
|
|
|
|||
8
tests/metagpt/provider/conftest.py
Normal file
8
tests/metagpt/provider/conftest.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def llm_mock(rsp_cache, mocker, request):
|
||||
# An empty fixture to overwrite the global llm_mock fixture
|
||||
# because in provider folder, we want to test the aask and aask functions for the specific models
|
||||
pass
|
||||
3
tests/metagpt/provider/postprocess/__init__.py
Normal file
3
tests/metagpt/provider/postprocess/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
|
||||
from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin
|
||||
|
||||
raw_output = """
|
||||
[CONTENT]
|
||||
{
|
||||
"Original Requirements": "xxx"
|
||||
}
|
||||
[/CONTENT]
|
||||
"""
|
||||
raw_schema = {
|
||||
"title": "prd",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Original Requirements": {"title": "Original Requirements", "type": "string"},
|
||||
},
|
||||
"required": [
|
||||
"Original Requirements",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_llm_post_process_plugin():
|
||||
post_process_plugin = BasePostProcessPlugin()
|
||||
|
||||
output = post_process_plugin.run(output=raw_output, schema=raw_schema)
|
||||
assert "Original Requirements" in output
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import (
|
||||
raw_output,
|
||||
raw_schema,
|
||||
)
|
||||
|
||||
|
||||
def test_llm_output_postprocess():
|
||||
output = llm_output_postprocess(output=raw_output, schema=raw_schema)
|
||||
assert "Original Requirements" in output
|
||||
34
tests/metagpt/provider/test_anthropic_api.py
Normal file
34
tests/metagpt/provider/test_anthropic_api.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Claude2
|
||||
|
||||
|
||||
import pytest
|
||||
from anthropic.resources.completions import Completion
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
|
||||
CONFIG.anthropic_api_key = "xxx"
|
||||
|
||||
prompt = "who are you"
|
||||
resp = "I'am Claude2"
|
||||
|
||||
|
||||
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
|
||||
assert resp == Claude2().ask(prompt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
|
||||
assert resp == await Claude2().aask(prompt)
|
||||
14
tests/metagpt/provider/test_azure_openai_api.py
Normal file
14
tests/metagpt/provider/test_azure_openai_api.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
|
||||
CONFIG.OPENAI_API_VERSION = "xx"
|
||||
CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value
|
||||
|
||||
|
||||
def test_azure_openai_api():
|
||||
_ = AzureOpenAILLM()
|
||||
|
|
@ -3,13 +3,104 @@
|
|||
"""
|
||||
@Time : 2023/5/7 17:40
|
||||
@Author : alexanderwu
|
||||
@File : test_base_gpt_api.py
|
||||
@File : test_base_llm.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
default_chat_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'am GPT",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
prompt_msg = "who are you"
|
||||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
|
||||
def test_message():
|
||||
message = Message(role='user', content='wtf')
|
||||
assert 'role' in message.to_dict()
|
||||
assert 'user' in str(message)
|
||||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
return resp_content
|
||||
|
||||
async def close(self):
|
||||
return default_chat_resp
|
||||
|
||||
|
||||
def test_base_llm():
|
||||
message = Message(role="user", content="hello")
|
||||
assert "role" in message.to_dict()
|
||||
assert "user" in str(message)
|
||||
|
||||
base_llm = MockBaseLLM()
|
||||
|
||||
openai_funccall_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "test",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_Y5r6Ddr2Qc2ZrqgfwzPX5l72",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "execute",
|
||||
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
func: dict = base_llm.get_choice_function(openai_funccall_resp)
|
||||
assert func == {
|
||||
"name": "execute",
|
||||
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
|
||||
}
|
||||
|
||||
func_args: dict = base_llm.get_choice_function_arguments(openai_funccall_resp)
|
||||
assert func_args == {"language": "python", "code": "print('Hello, World!')"}
|
||||
|
||||
choice_text = base_llm.get_choice_text(openai_funccall_resp)
|
||||
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
|
||||
|
||||
# resp = base_llm.ask(prompt_msg)
|
||||
# assert resp == resp_content
|
||||
|
||||
# resp = base_llm.ask_batch([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
||||
# resp = base_llm.ask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_base_llm():
|
||||
base_llm = MockBaseLLM()
|
||||
|
||||
resp = await base_llm.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_llm.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_llm.aask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
118
tests/metagpt/provider/test_fireworks_api.py
Normal file
118
tests/metagpt/provider/test_fireworks_api.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of fireworks api
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireworksLLM,
|
||||
)
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
|
||||
CONFIG.fireworks_api_key = "xxx"
|
||||
CONFIG.max_budget = 10
|
||||
CONFIG.calc_usage = True
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="accounts/fireworks/models/llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=dict(default_resp.usage),
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
def test_fireworks_costmanager():
|
||||
cost_manager = FireworksCostManager()
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
|
||||
|
||||
cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat")
|
||||
assert cost_manager.total_cost == 0.5
|
||||
|
||||
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield default_resp_chunk
|
||||
|
||||
return Iterator()
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
fireworks_gpt = FireworksLLM()
|
||||
fireworks_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
fireworks_gpt._update_costs(
|
||||
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
|
||||
)
|
||||
assert fireworks_gpt.get_costs() == Costs(
|
||||
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
|
||||
)
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
136
tests/metagpt/provider/test_general_api_base.py
Normal file
136
tests/metagpt/provider/test_general_api_base.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import os
|
||||
from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import requests
|
||||
from openai import OpenAIError
|
||||
|
||||
from metagpt.provider.general_api_base import (
|
||||
APIRequestor,
|
||||
ApiType,
|
||||
OpenAIResponse,
|
||||
_aiohttp_proxies_arg,
|
||||
_build_api_url,
|
||||
_make_session,
|
||||
_requests_proxies_arg,
|
||||
log_debug,
|
||||
log_info,
|
||||
log_warn,
|
||||
logfmt,
|
||||
parse_stream,
|
||||
parse_stream_helper,
|
||||
)
|
||||
|
||||
|
||||
def test_basic():
|
||||
_ = ApiType.from_str("azure")
|
||||
_ = ApiType.from_str("azuread")
|
||||
_ = ApiType.from_str("openai")
|
||||
with pytest.raises(OpenAIError):
|
||||
_ = ApiType.from_str("xx")
|
||||
|
||||
os.environ.setdefault("LLM_LOG", "debug")
|
||||
log_debug("debug")
|
||||
log_warn("warn")
|
||||
log_info("info")
|
||||
|
||||
logfmt({"k1": b"v1", "k2": 1, "k3": "a b"})
|
||||
|
||||
_build_api_url(url="http://www.baidu.com/s?wd=", query="baidu")
|
||||
|
||||
|
||||
def test_openai_response():
|
||||
resp = OpenAIResponse(data=[], headers={"retry-after": 3})
|
||||
assert resp.request_id is None
|
||||
assert resp.retry_after == 3
|
||||
assert resp.operation_location is None
|
||||
assert resp.organization is None
|
||||
assert resp.response_ms is None
|
||||
|
||||
|
||||
def test_proxy():
|
||||
assert _requests_proxies_arg(proxy=None) is None
|
||||
|
||||
proxy = "127.0.0.1:80"
|
||||
assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy}
|
||||
proxy_dict = {"http": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
assert _aiohttp_proxies_arg(proxy_dict) == proxy
|
||||
proxy_dict = {"https": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
assert _aiohttp_proxies_arg(proxy_dict) == proxy
|
||||
|
||||
assert _make_session() is not None
|
||||
|
||||
assert _aiohttp_proxies_arg(None) is None
|
||||
assert _aiohttp_proxies_arg("test") == "test"
|
||||
with pytest.raises(ValueError):
|
||||
_aiohttp_proxies_arg(-1)
|
||||
|
||||
|
||||
def test_parse_stream():
|
||||
assert parse_stream_helper(None) is None
|
||||
assert parse_stream_helper(b"data: [DONE]") is None
|
||||
assert parse_stream_helper(b"data: test") == "test"
|
||||
assert parse_stream_helper(b"test") is None
|
||||
for line in parse_stream([b"data: test"]):
|
||||
assert line == "test"
|
||||
|
||||
|
||||
api_requestor = APIRequestor(base_url="http://www.baidu.com")
|
||||
|
||||
|
||||
def mock_interpret_response(
|
||||
self, result: requests.Response, stream: bool
|
||||
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
|
||||
return b"baidu", False
|
||||
|
||||
|
||||
async def mock_interpret_async_response(
|
||||
self, result: aiohttp.ClientResponse, stream: bool
|
||||
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
|
||||
return b"baidu", True
|
||||
|
||||
|
||||
def test_requestor_headers():
|
||||
# validate_headers
|
||||
headers = api_requestor._validate_headers(None)
|
||||
assert not headers
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers(-1)
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers({1: 2})
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers({"test": 1})
|
||||
supplied_headers = {"test": "test"}
|
||||
assert api_requestor._validate_headers(supplied_headers) == supplied_headers
|
||||
|
||||
api_requestor.organization = "test"
|
||||
api_requestor.api_version = "test123"
|
||||
api_requestor.api_type = ApiType.OPEN_AI
|
||||
request_id = "test123"
|
||||
headers = api_requestor.request_headers(method="post", extra={}, request_id=request_id)
|
||||
assert headers["LLM-Organization"] == api_requestor.organization
|
||||
assert headers["LLM-Version"] == api_requestor.api_version
|
||||
assert headers["X-Request-Id"] == request_id
|
||||
|
||||
|
||||
def test_api_requestor(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response)
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
|
||||
resp, _, _ = api_requestor.request(method="post", url="/s?wd=baidu")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_api_requestor(mocker):
|
||||
mocker.patch(
|
||||
"metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response
|
||||
)
|
||||
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
|
||||
resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu")
|
||||
33
tests/metagpt/provider/test_general_api_requestor.py
Normal file
33
tests/metagpt/provider/test_general_api_requestor.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of APIRequestor
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.general_api_requestor import (
|
||||
GeneralAPIRequestor,
|
||||
parse_stream,
|
||||
parse_stream_helper,
|
||||
)
|
||||
|
||||
api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com")
|
||||
|
||||
|
||||
def test_parse_stream():
|
||||
assert parse_stream_helper(None) is None
|
||||
assert parse_stream_helper(b"data: [DONE]") is None
|
||||
assert parse_stream_helper(b"data: test") == b"test"
|
||||
assert parse_stream_helper(b"test") is None
|
||||
for line in parse_stream([b"data: test"]):
|
||||
assert line == b"test"
|
||||
|
||||
|
||||
def test_api_requestor():
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
assert b"baidu" in resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_api_requestor():
|
||||
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
|
||||
assert b"baidu" in resp
|
||||
89
tests/metagpt/provider/test_google_gemini_api.py
Normal file
89
tests/metagpt/provider/test_google_gemini_api.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of google gemini api
|
||||
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from google.ai import generativelanguage as glm
|
||||
from google.generativeai.types import content_types
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.google_gemini_api import GeminiLLM
|
||||
|
||||
CONFIG.gemini_api_key = "xx"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockGeminiResponse(ABC):
|
||||
text: str
|
||||
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "parts": prompt_msg}]
|
||||
resp_content = "I'm gemini from google"
|
||||
default_resp = MockGeminiResponse(text=resp_content)
|
||||
|
||||
|
||||
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
|
||||
return glm.CountTokensResponse(total_tokens=20)
|
||||
|
||||
|
||||
async def mock_gemini_count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
|
||||
return glm.CountTokensResponse(total_tokens=20)
|
||||
|
||||
|
||||
def mock_gemini_generate_content(self, **kwargs) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_gemini_generate_content_async(self, stream: bool = False, **kwargs) -> MockGeminiResponse:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield default_resp
|
||||
|
||||
return Iterator()
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens", mock_gemini_count_tokens)
|
||||
mocker.patch(
|
||||
"metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens_async", mock_gemini_count_tokens_async
|
||||
)
|
||||
mocker.patch("google.generativeai.generative_models.GenerativeModel.generate_content", mock_gemini_generate_content)
|
||||
mocker.patch(
|
||||
"google.generativeai.generative_models.GenerativeModel.generate_content_async",
|
||||
mock_gemini_generate_content_async,
|
||||
)
|
||||
|
||||
gemini_gpt = GeminiLLM()
|
||||
|
||||
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
|
||||
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
|
||||
|
||||
usage = gemini_gpt.get_usage(messages, resp_content)
|
||||
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
|
||||
|
||||
resp = gemini_gpt.completion(messages)
|
||||
assert resp == default_resp
|
||||
|
||||
resp = await gemini_gpt.acompletion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
31
tests/metagpt/provider/test_human_provider.py
Normal file
31
tests/metagpt/provider/test_human_provider.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of HumanProvider
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
resp_content = "test"
|
||||
resp_exit = "exit"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("builtins.input", lambda _: resp_content)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
resp = human_provider.ask(resp_content)
|
||||
assert resp == resp_content
|
||||
resp = await human_provider.aask(None)
|
||||
assert resp_content == resp
|
||||
|
||||
mocker.patch("builtins.input", lambda _: resp_exit)
|
||||
with pytest.raises(SystemExit):
|
||||
human_provider.ask(resp_exit)
|
||||
|
||||
resp = await human_provider.acompletion([])
|
||||
assert not resp
|
||||
|
||||
resp = await human_provider.acompletion_text([])
|
||||
assert resp == ""
|
||||
14
tests/metagpt/provider/test_metagpt_api.py
Normal file
14
tests/metagpt/provider/test_metagpt_api.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_api.py
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_llm():
|
||||
llm = LLM(provider=LLMProviderEnum.METAGPT)
|
||||
assert llm
|
||||
17
tests/metagpt/provider/test_metagpt_llm_api.py
Normal file
17
tests/metagpt/provider/test_metagpt_llm_api.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/30
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_llm_api.py
|
||||
"""
|
||||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
|
||||
|
||||
def test_metagpt():
|
||||
llm = MetaGPTLLM()
|
||||
assert llm
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_metagpt()
|
||||
62
tests/metagpt/provider/test_ollama_api.py
Normal file
62
tests/metagpt/provider/test_ollama_api.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of ollama api
|
||||
|
||||
import json
|
||||
from typing import Any, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm ollama"
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_content}}
|
||||
|
||||
CONFIG.ollama_api_base = "http://xxx"
|
||||
CONFIG.max_budget = 10
|
||||
|
||||
|
||||
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
events = [
|
||||
b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}',
|
||||
b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}',
|
||||
]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
yield event
|
||||
|
||||
return Iterator(), None, None
|
||||
else:
|
||||
raw_default_resp = default_resp.copy()
|
||||
raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20})
|
||||
return json.dumps(raw_default_resp).encode(), None, None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
|
||||
|
||||
ollama_gpt = OllamaLLM()
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
95
tests/metagpt/provider/test_open_llm_api.py
Normal file
95
tests/metagpt/provider/test_open_llm_api.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.open_llm_api import OpenLLM
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
|
||||
CONFIG.max_budget = 10
|
||||
CONFIG.calc_usage = True
|
||||
|
||||
resp_content = "I'm llama2"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703302755,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield default_resp_chunk
|
||||
|
||||
return Iterator()
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openllm_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
openllm_gpt = OpenLLM()
|
||||
openllm_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
assert openllm_gpt.get_costs() == Costs(
|
||||
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
|
||||
)
|
||||
|
||||
resp = await openllm_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
resp = await openllm_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
@ -1,12 +1,17 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import UserMessage
|
||||
|
||||
CONFIG.openai_proxy = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -16,7 +21,7 @@ async def test_aask_code():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -26,7 +31,7 @@ async def test_aask_code_str():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -92,3 +97,77 @@ async def test_ask_code_steps2():
|
|||
assert "Step 1" in rsp["code"]
|
||||
assert "Step 2" in rsp["code"]
|
||||
assert "Step 3" in rsp["code"]
|
||||
|
||||
|
||||
class TestOpenAI:
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_azure(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_azure_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
||||
def test_make_client_kwargs_without_proxy(self, config):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert "http_client" not in kwargs
|
||||
|
||||
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert "http_client" not in kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy(self, config_proxy):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_proxy
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure_proxy
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
|
|
|||
|
|
@ -1,11 +1,61 @@
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.spark_api import SparkAPI
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
|
||||
CONFIG.spark_appid = "xxx"
|
||||
CONFIG.spark_api_secret = "xxx"
|
||||
CONFIG.spark_api_key = "xxx"
|
||||
CONFIG.domain = "xxxxxx"
|
||||
CONFIG.spark_url = "xxxx"
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
||||
|
||||
def test_message():
|
||||
llm = SparkAPI()
|
||||
class MockWebSocketApp(object):
|
||||
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
|
||||
pass
|
||||
|
||||
logger.info(llm.ask('只回答"收到了"这三个字。'))
|
||||
result = llm.ask('写一篇五百字的日记')
|
||||
logger.info(result)
|
||||
assert len(result) > 100
|
||||
def run_forever(self, sslopt=None):
|
||||
pass
|
||||
|
||||
|
||||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(text=prompt_msg)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
assert ret == ""
|
||||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
spark_gpt = SparkLLM()
|
||||
|
||||
resp = await spark_gpt.acompletion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -1,47 +1,89 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of ZhiPuAIGPTAPI
|
||||
# @Desc : the unittest of ZhiPuAILLM
|
||||
|
||||
import pytest
|
||||
from zhipuai.utils.sse_client import Event
|
||||
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
|
||||
CONFIG.zhipuai_api_key = "xxx.xxx"
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"choices": [
|
||||
{"role": "assistant", "content": "I'm chatglm-turbo"}
|
||||
]
|
||||
}
|
||||
"choices": [{"role": "assistant", "content": resp_content}],
|
||||
"usage": {"prompt_tokens": 20, "completion_tokens": 20},
|
||||
},
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "who are you"}
|
||||
]
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> dict:
|
||||
def mock_zhipuai_invoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
def test_zhipuai_completion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask)
|
||||
|
||||
resp = ZhiPuAIGPTAPI().completion(messages)
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dict:
|
||||
async def mock_zhipuai_ainvoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_zhipuai_asse_invoke(**kwargs):
|
||||
class MockResponse(object):
|
||||
async def _aread(self):
|
||||
class Iterator(object):
|
||||
events = [
|
||||
Event(id="xxx", event="add", data=resp_content, retry=0),
|
||||
Event(
|
||||
id="xxx",
|
||||
event="finish",
|
||||
data="",
|
||||
meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}',
|
||||
),
|
||||
]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
yield event
|
||||
|
||||
async for chunk in Iterator():
|
||||
yield chunk
|
||||
|
||||
async def async_events(self):
|
||||
async for chunk in self._aread():
|
||||
yield chunk
|
||||
|
||||
return MockResponse()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke)
|
||||
|
||||
resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False)
|
||||
zhipu_gpt = ZhiPuAILLM()
|
||||
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
def test_zhipuai_proxy():
|
||||
# CONFIG.openai_proxy = "http://127.0.0.1:8080"
|
||||
_ = ZhiPuAILLM()
|
||||
# assert openai.proxy == CONFIG.openai_proxy
|
||||
|
|
|
|||
3
tests/metagpt/provider/zhipuai/__init__.py
Normal file
3
tests/metagpt/provider/zhipuai/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
26
tests/metagpt/provider/zhipuai/test_async_sse_client.py
Normal file
26
tests/metagpt/provider/zhipuai/test_async_sse_client.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_sse_client():
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"data: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=Iterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert event.data, "test_value"
|
||||
|
||||
class InvalidIterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"invalid: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert not event
|
||||
44
tests/metagpt/provider/zhipuai/test_zhipu_model_api.py
Normal file
44
tests/metagpt/provider/zhipuai/test_zhipu_model_api.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
import pytest
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import InvokeType
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
||||
api_key = "xxx.xxx"
|
||||
zhipuai.api_key = api_key
|
||||
|
||||
default_resp = b'{"result": "test response"}'
|
||||
|
||||
|
||||
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
||||
return default_resp, None, None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipu_model_api(mocker):
|
||||
header = ZhiPuModelAPI.get_header()
|
||||
zhipuai_default_headers.update({"Authorization": api_key})
|
||||
assert header == zhipuai_default_headers
|
||||
|
||||
sse_header = ZhiPuModelAPI.get_sse_header()
|
||||
assert len(sse_header["Authorization"]) == 191
|
||||
|
||||
url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"})
|
||||
assert url_prefix == "https://open.bigmodel.cn/api"
|
||||
assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke"
|
||||
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest)
|
||||
result = await ZhiPuModelAPI.arequest(
|
||||
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
|
||||
)
|
||||
assert result == default_resp
|
||||
|
||||
result = await ZhiPuModelAPI.ainvoke()
|
||||
assert result["result"] == "test response"
|
||||
|
|
@ -3,12 +3,14 @@
|
|||
"""
|
||||
@Time : 2023/5/12 13:05
|
||||
@Author : alexanderwu
|
||||
@File : mock.py
|
||||
@File : mock_markdown.py
|
||||
"""
|
||||
from metagpt.actions import BossRequirement, WriteDesign, WritePRD, WriteTasks
|
||||
import json
|
||||
|
||||
from metagpt.actions import UserRequirement, WriteDesign, WritePRD, WriteTasks
|
||||
from metagpt.schema import Message
|
||||
|
||||
BOSS_REQUIREMENT = """开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"""
|
||||
USER_REQUIREMENT = """开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"""
|
||||
|
||||
DETAIL_REQUIREMENT = """需求:开发一个基于LLM(大语言模型)与私有知识库的搜索引擎,希望有几点能力
|
||||
1. 用户可以在私有知识库进行搜索,再根据大语言模型进行总结,输出的结果包括了总结
|
||||
|
|
@ -71,7 +73,7 @@ PRD = '''## 原始需求
|
|||
```
|
||||
'''
|
||||
|
||||
SYSTEM_DESIGN = '''## Python package name
|
||||
SYSTEM_DESIGN = """## Project name
|
||||
```python
|
||||
"smart_search_engine"
|
||||
```
|
||||
|
|
@ -94,7 +96,7 @@ SYSTEM_DESIGN = '''## Python package name
|
|||
]
|
||||
```
|
||||
|
||||
## Data structures and interface definitions
|
||||
## Data structures and interfaces
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Main {
|
||||
|
|
@ -149,10 +151,36 @@ sequenceDiagram
|
|||
S-->>SE: return summary
|
||||
SE-->>M: return summary
|
||||
```
|
||||
'''
|
||||
"""
|
||||
|
||||
JSON_TASKS = {
|
||||
"Logic Analysis": """
|
||||
在这个项目中,所有的模块都依赖于“SearchEngine”类,这是主入口,其他的模块(Index、Ranking和Summary)都通过它交互。另外,"Index"类又依赖于"KnowledgeBase"类,因为它需要从知识库中获取数据。
|
||||
|
||||
- "main.py"包含"Main"类,是程序的入口点,它调用"SearchEngine"进行搜索操作,所以在其他任何模块之前,"SearchEngine"必须首先被定义。
|
||||
- "search.py"定义了"SearchEngine"类,它依赖于"Index"、"Ranking"和"Summary",因此,这些模块需要在"search.py"之前定义。
|
||||
- "index.py"定义了"Index"类,它从"knowledge_base.py"获取数据来创建索引,所以"knowledge_base.py"需要在"index.py"之前定义。
|
||||
- "ranking.py"和"summary.py"相对独立,只需确保在"search.py"之前定义。
|
||||
- "knowledge_base.py"是独立的模块,可以优先开发。
|
||||
- "interface.py"、"user_feedback.py"、"security.py"、"testing.py"和"monitoring.py"看起来像是功能辅助模块,可以在主要功能模块开发完成后并行开发。
|
||||
""",
|
||||
"Task list": [
|
||||
"smart_search_engine/knowledge_base.py",
|
||||
"smart_search_engine/index.py",
|
||||
"smart_search_engine/ranking.py",
|
||||
"smart_search_engine/summary.py",
|
||||
"smart_search_engine/search.py",
|
||||
"smart_search_engine/main.py",
|
||||
"smart_search_engine/interface.py",
|
||||
"smart_search_engine/user_feedback.py",
|
||||
"smart_search_engine/security.py",
|
||||
"smart_search_engine/testing.py",
|
||||
"smart_search_engine/monitoring.py",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
TASKS = '''## Logic Analysis
|
||||
TASKS = """## Logic Analysis
|
||||
|
||||
在这个项目中,所有的模块都依赖于“SearchEngine”类,这是主入口,其他的模块(Index、Ranking和Summary)都通过它交互。另外,"Index"类又依赖于"KnowledgeBase"类,因为它需要从知识库中获取数据。
|
||||
|
||||
|
|
@ -181,7 +209,7 @@ task_list = [
|
|||
]
|
||||
```
|
||||
这个任务列表首先定义了最基础的模块,然后是依赖这些模块的模块,最后是辅助模块。可以根据团队的能力和资源,同时开发多个任务,只要满足依赖关系。例如,在开发"search.py"之前,可以同时开发"knowledge_base.py"、"index.py"、"ranking.py"和"summary.py"。
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
TASKS_TOMATO_CLOCK = '''## Required Python third-party packages: Provided in requirements.txt format
|
||||
|
|
@ -224,35 +252,38 @@ task_list = [
|
|||
TASK = """smart_search_engine/knowledge_base.py"""
|
||||
|
||||
STRS_FOR_PARSING = [
|
||||
"""
|
||||
"""
|
||||
## 1
|
||||
```python
|
||||
a
|
||||
```
|
||||
""",
|
||||
"""
|
||||
"""
|
||||
##2
|
||||
```python
|
||||
"a"
|
||||
```
|
||||
""",
|
||||
"""
|
||||
"""
|
||||
## 3
|
||||
```python
|
||||
a = "a"
|
||||
```
|
||||
""",
|
||||
"""
|
||||
"""
|
||||
## 4
|
||||
```python
|
||||
a = 'a'
|
||||
```
|
||||
"""
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
class MockMessages:
|
||||
req = Message(role="Boss", content=BOSS_REQUIREMENT, cause_by=BossRequirement)
|
||||
req = Message(role="User", content=USER_REQUIREMENT, cause_by=UserRequirement)
|
||||
prd = Message(role="Product Manager", content=PRD, cause_by=WritePRD)
|
||||
system_design = Message(role="Architect", content=SYSTEM_DESIGN, cause_by=WriteDesign)
|
||||
tasks = Message(role="Project Manager", content=TASKS, cause_by=WriteTasks)
|
||||
json_tasks = Message(
|
||||
role="Project Manager", content=json.dumps(JSON_TASKS, ensure_ascii=False), cause_by=WriteTasks
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,18 +4,41 @@
|
|||
@Time : 2023/5/20 14:37
|
||||
@Author : alexanderwu
|
||||
@File : test_architect.py
|
||||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteDesign, WritePRD
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PRDS_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, awrite
|
||||
from tests.metagpt.roles.mock import MockMessages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect():
|
||||
# Prerequisites
|
||||
filename = uuid.uuid4().hex + ".json"
|
||||
await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)
|
||||
|
||||
role = Architect()
|
||||
role.recv(MockMessages.req)
|
||||
rsp = await role.handle(MockMessages.prd)
|
||||
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
assert rsp.cause_by == any_to_str(WriteDesign)
|
||||
|
||||
# test update
|
||||
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
|
||||
assert rsp
|
||||
assert rsp.cause_by == any_to_str(WriteDesign)
|
||||
assert len(rsp.content) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
132
tests/metagpt/roles/test_assistant.py
Normal file
132
tests/metagpt/roles/test_assistant.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/25
|
||||
@Author : mashenquan
|
||||
@File : test_asssistant.py
|
||||
@Desc : Used by AgentStore.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions.skill_action import SkillAction
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
from metagpt.roles.assistant import Assistant
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
CONFIG.language = "Chinese"
|
||||
|
||||
class Input(BaseModel):
|
||||
memory: BrainMemory
|
||||
language: str
|
||||
agent_description: str
|
||||
cause_by: str
|
||||
|
||||
inputs = [
|
||||
{
|
||||
"memory": {
|
||||
"history": [
|
||||
{
|
||||
"content": "who is tulin",
|
||||
"role": "user",
|
||||
"id": "1",
|
||||
},
|
||||
{"content": "The one who eaten a poison apple.", "role": "assistant"},
|
||||
],
|
||||
"knowledge": [{"content": "tulin is a scientist."}],
|
||||
"last_talk": "Do you have a poison apple?",
|
||||
},
|
||||
"language": "English",
|
||||
"agent_description": "chatterbox",
|
||||
"cause_by": any_to_str(TalkAction),
|
||||
},
|
||||
{
|
||||
"memory": {
|
||||
"history": [
|
||||
{
|
||||
"content": "can you draw me an picture?",
|
||||
"role": "user",
|
||||
"id": "1",
|
||||
},
|
||||
{"content": "Yes, of course. What do you want me to draw", "role": "assistant"},
|
||||
],
|
||||
"knowledge": [{"content": "tulin is a scientist."}],
|
||||
"last_talk": "Draw me an apple.",
|
||||
},
|
||||
"language": "English",
|
||||
"agent_description": "painter",
|
||||
"cause_by": any_to_str(SkillAction),
|
||||
},
|
||||
]
|
||||
CONFIG.agent_skills = [
|
||||
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
|
||||
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
|
||||
]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
CONFIG.language = seed.language
|
||||
CONFIG.agent_description = seed.agent_description
|
||||
role = Assistant(language="Chinese")
|
||||
role.memory = seed.memory # Restore historical conversation content.
|
||||
while True:
|
||||
has_action = await role.think()
|
||||
if not has_action:
|
||||
break
|
||||
msg: Message = await role.act()
|
||||
# logger.info(msg)
|
||||
assert msg
|
||||
assert msg.cause_by == seed.cause_by
|
||||
assert msg.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"memory",
|
||||
[
|
||||
{
|
||||
"history": [
|
||||
{
|
||||
"content": "can you draw me an picture?",
|
||||
"role": "user",
|
||||
"id": "1",
|
||||
},
|
||||
{"content": "Yes, of course. What do you want me to draw", "role": "assistant"},
|
||||
],
|
||||
"knowledge": [{"content": "tulin is a scientist."}],
|
||||
"last_talk": "Draw me an apple.",
|
||||
}
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory(memory):
|
||||
role = Assistant()
|
||||
role.load_memory(memory)
|
||||
|
||||
val = role.get_memory()
|
||||
assert val
|
||||
|
||||
await role.talk("draw apple")
|
||||
|
||||
agent_skills = CONFIG.agent_skills
|
||||
CONFIG.agent_skills = []
|
||||
try:
|
||||
await role.think()
|
||||
finally:
|
||||
CONFIG.agent_skills = agent_skills
|
||||
assert isinstance(role.rc.todo, TalkAction)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -4,44 +4,62 @@
|
|||
@Time : 2023/5/12 10:14
|
||||
@Author : alexanderwu
|
||||
@File : test_engineer.py
|
||||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCode, WriteTasks
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
PRDS_FILE_REPO,
|
||||
REQUIREMENT_FILENAME,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.utils.common import CodeParser
|
||||
from tests.metagpt.roles.mock import (
|
||||
STRS_FOR_PARSING,
|
||||
TASKS,
|
||||
TASKS_TOMATO_CLOCK,
|
||||
MockMessages,
|
||||
)
|
||||
from metagpt.schema import CodingContext, Message
|
||||
from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import ChangeType
|
||||
from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer():
|
||||
engineer = Engineer()
|
||||
# Prerequisites
|
||||
rqno = "20231221155954.json"
|
||||
await FileRepository.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content)
|
||||
await FileRepository.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content)
|
||||
await FileRepository.save_file(
|
||||
rqno, relative_path=SYSTEM_DESIGN_FILE_REPO, content=MockMessages.system_design.content
|
||||
)
|
||||
await FileRepository.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content)
|
||||
|
||||
engineer.recv(MockMessages.req)
|
||||
engineer.recv(MockMessages.prd)
|
||||
engineer.recv(MockMessages.system_design)
|
||||
rsp = await engineer.handle(MockMessages.tasks)
|
||||
engineer = Engineer()
|
||||
rsp = await engineer.run(Message(content="", cause_by=WriteTasks))
|
||||
|
||||
logger.info(rsp)
|
||||
assert "all done." == rsp.content
|
||||
assert rsp.cause_by == any_to_str(WriteCode)
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
assert src_file_repo.changed_files
|
||||
|
||||
|
||||
def test_parse_str():
|
||||
for idx, i in enumerate(STRS_FOR_PARSING):
|
||||
text = CodeParser.parse_str(f"{idx+1}", i)
|
||||
text = CodeParser.parse_str(f"{idx + 1}", i)
|
||||
# logger.info(text)
|
||||
assert text == 'a'
|
||||
assert text == "a"
|
||||
|
||||
|
||||
def test_parse_blocks():
|
||||
tasks = CodeParser.parse_blocks(TASKS)
|
||||
logger.info(tasks.keys())
|
||||
assert 'Task list' in tasks.keys()
|
||||
assert "Task list" in tasks.keys()
|
||||
|
||||
|
||||
target_list = [
|
||||
|
|
@ -60,14 +78,11 @@ target_list = [
|
|||
|
||||
|
||||
def test_parse_file_list():
|
||||
tasks = CodeParser.parse_file_list("任务列表", TASKS)
|
||||
tasks = CodeParser.parse_file_list("Task list", TASKS)
|
||||
logger.info(tasks)
|
||||
assert isinstance(tasks, list)
|
||||
assert target_list == tasks
|
||||
|
||||
file_list = CodeParser.parse_file_list("Task list", TASKS_TOMATO_CLOCK, lang="python")
|
||||
logger.info(file_list)
|
||||
|
||||
|
||||
target_code = """task_list = [
|
||||
"smart_search_engine/knowledge_base.py",
|
||||
|
|
@ -86,7 +101,63 @@ target_code = """task_list = [
|
|||
|
||||
|
||||
def test_parse_code():
|
||||
code = CodeParser.parse_code("任务列表", TASKS, lang="python")
|
||||
code = CodeParser.parse_code("Task list", TASKS, lang="python")
|
||||
logger.info(code)
|
||||
assert isinstance(code, str)
|
||||
assert target_code == code
|
||||
|
||||
|
||||
def test_todo():
|
||||
role = Engineer()
|
||||
assert role.todo == any_to_name(WriteCode)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_coding_context():
|
||||
# Prerequisites
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
deps = json.loads(await aread(demo_path / "dependencies.json"))
|
||||
dependency = await CONFIG.git_repo.get_dependency()
|
||||
for k, v in deps.items():
|
||||
await dependency.update(k, set(v))
|
||||
data = await aread(demo_path / "system_design.json")
|
||||
rqno = "20231221155954.json"
|
||||
await awrite(CONFIG.git_repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data)
|
||||
data = await aread(demo_path / "tasks.json")
|
||||
await awrite(CONFIG.git_repo.workdir / TASK_FILE_REPO / rqno, data)
|
||||
|
||||
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "game_2048"
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
|
||||
task_file_repo = CONFIG.git_repo.new_file_repository(relative_path=TASK_FILE_REPO)
|
||||
design_file_repo = CONFIG.git_repo.new_file_repository(relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
|
||||
filename = "game.py"
|
||||
ctx_doc = await Engineer._new_coding_doc(
|
||||
filename=filename,
|
||||
src_file_repo=src_file_repo,
|
||||
task_file_repo=task_file_repo,
|
||||
design_file_repo=design_file_repo,
|
||||
dependency=dependency,
|
||||
)
|
||||
assert ctx_doc
|
||||
assert ctx_doc.filename == filename
|
||||
assert ctx_doc.content
|
||||
ctx = CodingContext.model_validate_json(ctx_doc.content)
|
||||
assert ctx.filename == filename
|
||||
assert ctx.design_doc
|
||||
assert ctx.design_doc.content
|
||||
assert ctx.task_doc
|
||||
assert ctx.task_doc.content
|
||||
assert ctx.code_doc
|
||||
|
||||
CONFIG.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED})
|
||||
CONFIG.git_repo.commit("mock env")
|
||||
await src_file_repo.save(filename=filename, content="content")
|
||||
role = Engineer()
|
||||
assert not role.code_todos
|
||||
await role._new_code_actions()
|
||||
assert role.code_todos
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
|
||||
from metagpt.const import DATA_PATH, TEST_DATA_PATH
|
||||
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
|
|
@ -22,84 +23,35 @@ from metagpt.schema import Message
|
|||
[
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-1.pdf"),
|
||||
Path("../../../data/invoice_table/invoice-1.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
"城市": "深圳市",
|
||||
"总费用/元": 412.00,
|
||||
"开票日期": "2023年02月03日"
|
||||
}
|
||||
]
|
||||
Path("invoices/invoice-1.pdf"),
|
||||
Path("invoice_table/invoice-1.xlsx"),
|
||||
{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-2.png"),
|
||||
Path("../../../data/invoice_table/invoice-2.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "铁头",
|
||||
"城市": "广州市",
|
||||
"总费用/元": 898.00,
|
||||
"开票日期": "2023年03月17日"
|
||||
}
|
||||
]
|
||||
Path("invoices/invoice-2.png"),
|
||||
Path("invoice_table/invoice-2.xlsx"),
|
||||
{"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-3.jpg"),
|
||||
Path("../../../data/invoice_table/invoice-3.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "夏天",
|
||||
"城市": "福州市",
|
||||
"总费用/元": 2462.00,
|
||||
"开票日期": "2023年08月26日"
|
||||
}
|
||||
]
|
||||
Path("invoices/invoice-3.jpg"),
|
||||
Path("invoice_table/invoice-3.xlsx"),
|
||||
{"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-4.zip"),
|
||||
Path("../../../data/invoice_table/invoice-4.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
"城市": "深圳市",
|
||||
"总费用/元": 412.00,
|
||||
"开票日期": "2023年02月03日"
|
||||
},
|
||||
{
|
||||
"收款人": "铁头",
|
||||
"城市": "广州市",
|
||||
"总费用/元": 898.00,
|
||||
"开票日期": "2023年03月17日"
|
||||
},
|
||||
{
|
||||
"收款人": "夏天",
|
||||
"城市": "福州市",
|
||||
"总费用/元": 2462.00,
|
||||
"开票日期": "2023年08月26日"
|
||||
}
|
||||
]
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr_assistant(
|
||||
query: str,
|
||||
invoice_path: Path,
|
||||
invoice_table_path: Path,
|
||||
expected_result: list[dict]
|
||||
):
|
||||
invoice_path = Path.cwd() / invoice_path
|
||||
async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict):
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
role = InvoiceOCRAssistant()
|
||||
await role.run(Message(
|
||||
content=query,
|
||||
instruct_content={"file_path": invoice_path}
|
||||
))
|
||||
invoice_table_path = Path.cwd() / invoice_table_path
|
||||
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
|
||||
invoice_table_path = DATA_PATH / invoice_table_path
|
||||
df = pd.read_excel(invoice_table_path)
|
||||
dict_result = df.to_dict(orient='records')
|
||||
assert dict_result == expected_result
|
||||
|
||||
resp = df.to_dict(orient="records")
|
||||
assert isinstance(resp, list)
|
||||
assert len(resp) == 1
|
||||
resp = resp[0]
|
||||
assert expected_result["收款人"] == resp["收款人"]
|
||||
assert expected_result["城市"] in resp["城市"]
|
||||
assert float(expected_result["总费用/元"]) == float(resp["总费用/元"])
|
||||
assert expected_result["开票日期"] == resp["开票日期"]
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ from tests.metagpt.roles.mock import MockMessages
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_product_manager():
|
||||
async def test_product_manager(new_filename):
|
||||
product_manager = ProductManager()
|
||||
rsp = await product_manager.handle(MockMessages.req)
|
||||
rsp = await product_manager.run(MockMessages.req)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
assert "Product Goals" in rsp.content
|
||||
assert rsp.content == MockMessages.req.content
|
||||
|
|
|
|||
|
|
@ -15,5 +15,5 @@ from tests.metagpt.roles.mock import MockMessages
|
|||
@pytest.mark.asyncio
|
||||
async def test_project_manager():
|
||||
project_manager = ProjectManager()
|
||||
rsp = await project_manager.handle(MockMessages.system_design)
|
||||
rsp = await project_manager.run(MockMessages.system_design)
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -5,3 +5,59 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_qa_engineer.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import DebugError, RunCode, WriteTest
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.roles import QaEngineer
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, aread, awrite
|
||||
|
||||
|
||||
async def test_qa():
|
||||
# Prerequisites
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "qa/game_2048"
|
||||
data = await aread(filename=demo_path / "game.py", encoding="utf-8")
|
||||
await awrite(filename=CONFIG.src_workspace / "game.py", data=data, encoding="utf-8")
|
||||
await awrite(filename=Path(CONFIG.git_repo.workdir) / "requirements.txt", data="")
|
||||
|
||||
class MockEnv(Environment):
|
||||
msgs: List[Message] = Field(default_factory=list)
|
||||
|
||||
def publish_message(self, message: Message, peekable: bool = True) -> bool:
|
||||
self.msgs.append(message)
|
||||
return True
|
||||
|
||||
env = MockEnv()
|
||||
|
||||
role = QaEngineer()
|
||||
role.set_env(env)
|
||||
await role.run(with_message=Message(content="", cause_by=SummarizeCode))
|
||||
assert env.msgs
|
||||
assert env.msgs[0].cause_by == any_to_str(WriteTest)
|
||||
msg = env.msgs[0]
|
||||
env.msgs.clear()
|
||||
await role.run(with_message=msg)
|
||||
assert env.msgs
|
||||
assert env.msgs[0].cause_by == any_to_str(RunCode)
|
||||
msg = env.msgs[0]
|
||||
env.msgs.clear()
|
||||
await role.run(with_message=msg)
|
||||
assert env.msgs
|
||||
assert env.msgs[0].cause_by == any_to_str(DebugError)
|
||||
msg = env.msgs[0]
|
||||
env.msgs.clear()
|
||||
role.test_round_allowed = 1
|
||||
rsp = await role.run(with_message=msg)
|
||||
assert "Exceeding" in rsp.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -11,10 +11,12 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
|
|||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["dataiku", "datarobot"]'
|
||||
elif "Provide up to 4 queries related to your research topic" in prompt:
|
||||
return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \
|
||||
return (
|
||||
'["Dataiku machine learning platform", "DataRobot AI platform comparison", '
|
||||
'"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]'
|
||||
)
|
||||
elif "sort the remaining search results" in prompt:
|
||||
return '[1,2]'
|
||||
return "[1,2]"
|
||||
elif "Not relevant." in prompt:
|
||||
return "Not relevant" if random() > 0.5 else prompt[-100:]
|
||||
elif "provide a detailed research report" in prompt:
|
||||
|
|
@ -26,7 +28,27 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
|
|||
async def test_researcher(mocker):
|
||||
with TemporaryDirectory() as dirname:
|
||||
topic = "dataiku vs. datarobot"
|
||||
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
researcher.RESEARCH_PATH = Path(dirname)
|
||||
await researcher.Researcher().run(topic)
|
||||
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
|
||||
|
||||
|
||||
def test_write_report(mocker):
|
||||
with TemporaryDirectory() as dirname:
|
||||
for i, topic in enumerate(
|
||||
[
|
||||
("1./metagpt"),
|
||||
('2.:"metagpt'),
|
||||
("3.*?<>|metagpt"),
|
||||
("4. metagpt\n"),
|
||||
]
|
||||
):
|
||||
researcher.RESEARCH_PATH = Path(dirname)
|
||||
content = "# Research Report"
|
||||
researcher.Researcher().write_report(topic, content)
|
||||
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
22
tests/metagpt/roles/test_role.py
Normal file
22
tests/metagpt/roles/test_role.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of Role
|
||||
import pytest
|
||||
|
||||
from metagpt.llm import HumanProvider
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
def test_role_desc():
|
||||
role = Role(profile="Sales", desc="Best Seller")
|
||||
assert role.profile == "Sales"
|
||||
assert role.desc == "Best Seller"
|
||||
|
||||
|
||||
def test_role_human():
|
||||
role = Role(is_human=True)
|
||||
assert isinstance(role.llm, HumanProvider)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
158
tests/metagpt/roles/test_teacher.py
Normal file
158
tests/metagpt/roles/test_teacher.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/27 13:25
|
||||
@Author : mashenquan
|
||||
@File : test_teacher.py
|
||||
"""
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.roles.teacher import Teacher
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip
|
||||
async def test_init():
|
||||
class Inputs(BaseModel):
|
||||
name: str
|
||||
profile: str
|
||||
goal: str
|
||||
constraints: str
|
||||
desc: str
|
||||
kwargs: Optional[Dict] = None
|
||||
expect_name: str
|
||||
expect_profile: str
|
||||
expect_goal: str
|
||||
expect_constraints: str
|
||||
expect_desc: str
|
||||
|
||||
inputs = [
|
||||
{
|
||||
"name": "Lily{language}",
|
||||
"expect_name": "Lily{language}",
|
||||
"profile": "X {teaching_language}",
|
||||
"expect_profile": "X {teaching_language}",
|
||||
"goal": "Do {something_big}, {language}",
|
||||
"expect_goal": "Do {something_big}, {language}",
|
||||
"constraints": "Do in {key1}, {language}",
|
||||
"expect_constraints": "Do in {key1}, {language}",
|
||||
"kwargs": {},
|
||||
"desc": "aaa{language}",
|
||||
"expect_desc": "aaa{language}",
|
||||
},
|
||||
{
|
||||
"name": "Lily{language}",
|
||||
"expect_name": "LilyCN",
|
||||
"profile": "X {teaching_language}",
|
||||
"expect_profile": "X EN",
|
||||
"goal": "Do {something_big}, {language}",
|
||||
"expect_goal": "Do sleep, CN",
|
||||
"constraints": "Do in {key1}, {language}",
|
||||
"expect_constraints": "Do in HaHa, CN",
|
||||
"kwargs": {"language": "CN", "key1": "HaHa", "something_big": "sleep", "teaching_language": "EN"},
|
||||
"desc": "aaa{language}",
|
||||
"expect_desc": "aaaCN",
|
||||
},
|
||||
]
|
||||
|
||||
env = os.environ.copy()
|
||||
for i in inputs:
|
||||
seed = Inputs(**i)
|
||||
os.environ.clear()
|
||||
os.environ.update(env)
|
||||
CONFIG = Config()
|
||||
CONFIG.set_context(seed.kwargs)
|
||||
print(CONFIG.options)
|
||||
assert bool("language" in seed.kwargs) == bool("language" in CONFIG.options)
|
||||
|
||||
teacher = Teacher(
|
||||
name=seed.name,
|
||||
profile=seed.profile,
|
||||
goal=seed.goal,
|
||||
constraints=seed.constraints,
|
||||
desc=seed.desc,
|
||||
)
|
||||
assert teacher.name == seed.expect_name
|
||||
assert teacher.desc == seed.expect_desc
|
||||
assert teacher.profile == seed.expect_profile
|
||||
assert teacher.goal == seed.expect_goal
|
||||
assert teacher.constraints == seed.expect_constraints
|
||||
assert teacher.course_title == "teaching_plan"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_file_name():
|
||||
class Inputs(BaseModel):
|
||||
lesson_title: str
|
||||
ext: str
|
||||
expect: str
|
||||
|
||||
inputs = [
|
||||
{"lesson_title": "# @344\n12", "ext": ".md", "expect": "_344_12.md"},
|
||||
{"lesson_title": "1#@$%!*&\\/:*?\"<>|\n\t '1", "ext": ".cc", "expect": "1_1.cc"},
|
||||
]
|
||||
for i in inputs:
|
||||
seed = Inputs(**i)
|
||||
result = Teacher.new_file_name(seed.lesson_title, seed.ext)
|
||||
assert result == seed.expect
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
CONFIG.set_context({"language": "Chinese", "teaching_language": "English"})
|
||||
lesson = """
|
||||
UNIT 1 Making New Friends
|
||||
TOPIC 1 Welcome to China!
|
||||
Section A
|
||||
|
||||
1a Listen and number the following names.
|
||||
Jane Mari Kangkang Michael
|
||||
Look, listen and understand. Then practice the conversation.
|
||||
Work in groups. Introduce yourself using
|
||||
I ’m ... Then practice 1a
|
||||
with your own hometown or the following places.
|
||||
|
||||
1b Listen and number the following names
|
||||
Jane Michael Maria Kangkang
|
||||
1c Work in groups. Introduce yourself using I ’m ... Then practice 1a with your own hometown or the following places.
|
||||
China the USA the UK Hong Kong Beijing
|
||||
|
||||
2a Look, listen and understand. Then practice the conversation
|
||||
Hello!
|
||||
Hello!
|
||||
Hello!
|
||||
Hello! Are you Maria?
|
||||
No, I’m not. I’m Jane.
|
||||
Oh, nice to meet you, Jane
|
||||
Nice to meet you, too.
|
||||
Hi, Maria!
|
||||
Hi, Kangkang!
|
||||
Welcome to China!
|
||||
Thanks.
|
||||
|
||||
2b Work in groups. Make up a conversation with your own name and the
|
||||
following structures.
|
||||
A: Hello! / Good morning! / Hi! I’m ... Are you ... ?
|
||||
B: ...
|
||||
|
||||
3a Listen, say and trace
|
||||
Aa Bb Cc Dd Ee Ff Gg
|
||||
|
||||
3b Listen and number the following letters. Then circle the letters with the same sound as Bb.
|
||||
Aa Bb Cc Dd Ee Ff Gg
|
||||
|
||||
3c Match the big letters with the small ones. Then write them on the lines.
|
||||
"""
|
||||
teacher = Teacher()
|
||||
rsp = await teacher.run(Message(content=lesson))
|
||||
assert rsp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -5,23 +5,25 @@
|
|||
@Author : Stitch-z
|
||||
@File : test_tutorial_assistant.py
|
||||
"""
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.const import TUTORIAL_PATH
|
||||
from metagpt.roles.tutorial_assistant import TutorialAssistant
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("language", "topic"),
|
||||
[("Chinese", "Write a tutorial about Python")]
|
||||
)
|
||||
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")])
|
||||
async def test_tutorial_assistant(language: str, topic: str):
|
||||
topic = "Write a tutorial about MySQL"
|
||||
role = TutorialAssistant(language=language)
|
||||
msg = await role.run(topic)
|
||||
assert TUTORIAL_PATH.exists()
|
||||
filename = msg.content
|
||||
title = filename.split("/")[-1].split(".")[0]
|
||||
async with aiofiles.open(filename, mode="r") as reader:
|
||||
async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
|
||||
content = await reader.read()
|
||||
assert content.startswith(f"# {title}")
|
||||
assert "pip" in content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -1,22 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/22 02:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
#
|
||||
from metagpt.team import Team
|
||||
from metagpt.roles import ProductManager
|
||||
|
||||
from tests.metagpt.roles.ui_role import UI
|
||||
|
||||
|
||||
def test_add_ui():
|
||||
ui = UI()
|
||||
assert ui.profile == "UI Design"
|
||||
|
||||
|
||||
async def test_ui_role(idea: str, investment: float = 3.0, n_round: int = 5):
|
||||
"""Run a startup. Be a boss."""
|
||||
company = Team()
|
||||
company.hire([ProductManager(), UI()])
|
||||
company.invest(investment)
|
||||
company.start_project(idea)
|
||||
await company.run(n_round=n_round)
|
||||
|
|
@ -1,276 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/15 16:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import os
|
||||
import re
|
||||
from functools import wraps
|
||||
from importlib import import_module
|
||||
|
||||
from metagpt.actions import Action, ActionOutput, WritePRD
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools.sd_engine import SDEngine
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
# Context
|
||||
{context}
|
||||
|
||||
## Format example
|
||||
{format_example}
|
||||
-----
|
||||
Role: You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style.
|
||||
Requirements: Based on the context, fill in the following missing information, provide detailed HTML and CSS code
|
||||
Attention: Use '##' to split sections, not '#', and '## <SECTION_NAME>' SHOULD WRITE BEFORE the code and triple quote.
|
||||
|
||||
## UI Design Description:Provide as Plain text, place the design objective here
|
||||
## Selected Elements:Provide as Plain text, up to 5 specified elements, clear and simple
|
||||
## HTML Layout:Provide as Plain text, use standard HTML code
|
||||
## CSS Styles (styles.css):Provide as Plain text,use standard css code
|
||||
## Anything UNCLEAR:Provide as Plain text. Make clear here.
|
||||
|
||||
"""
|
||||
|
||||
FORMAT_EXAMPLE = """
|
||||
|
||||
## UI Design Description
|
||||
```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ```
|
||||
|
||||
## Selected Elements
|
||||
|
||||
Game Grid: The game grid is a rectangular...
|
||||
|
||||
Snake: The player controls a snake that moves across the grid...
|
||||
|
||||
Food: Food items (often represented as small objects or differently colored blocks)
|
||||
|
||||
Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score.
|
||||
|
||||
Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game.
|
||||
|
||||
|
||||
## HTML Layout
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Snake Game</title>
|
||||
<link rel="stylesheet" href="styles.css">
|
||||
</head>
|
||||
<body>
|
||||
<div class="game-grid">
|
||||
<!-- Snake will be dynamically generated here using JavaScript -->
|
||||
</div>
|
||||
<div class="food">
|
||||
<!-- Food will be dynamically generated here using JavaScript -->
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
## CSS Styles (styles.css)
|
||||
body {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
|
||||
.game-grid {
|
||||
width: 400px;
|
||||
height: 400px;
|
||||
display: grid;
|
||||
grid-template-columns: repeat(20, 1fr); /* Adjust to the desired grid size */
|
||||
grid-template-rows: repeat(20, 1fr);
|
||||
gap: 1px;
|
||||
background-color: #222;
|
||||
border: 1px solid #555;
|
||||
}
|
||||
|
||||
.game-grid div {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: #444;
|
||||
}
|
||||
|
||||
.snake-segment {
|
||||
background-color: #00cc66; /* Snake color */
|
||||
}
|
||||
|
||||
.food {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: #cc3300; /* Food color */
|
||||
position: absolute;
|
||||
}
|
||||
|
||||
/* Optional styles for a simple game over message */
|
||||
.game-over {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
color: #ff0000;
|
||||
display: none;
|
||||
}
|
||||
|
||||
## Anything UNCLEAR
|
||||
There are no unclear points.
|
||||
|
||||
"""
|
||||
|
||||
OUTPUT_MAPPING = {
|
||||
"UI Design Description": (str, ...),
|
||||
"Selected Elements": (str, ...),
|
||||
"HTML Layout": (str, ...),
|
||||
"CSS Styles (styles.css)": (str, ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
|
||||
|
||||
def load_engine(func):
|
||||
"""Decorator to load an engine by file name and engine name."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
file_name, engine_name = func(*args, **kwargs)
|
||||
engine_file = import_module(file_name, package="metagpt")
|
||||
ip_module_cls = getattr(engine_file, engine_name)
|
||||
try:
|
||||
engine = ip_module_cls()
|
||||
except:
|
||||
engine = None
|
||||
|
||||
return engine
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def parse(func):
|
||||
"""Decorator to parse information using regex pattern."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
context, pattern = func(*args, **kwargs)
|
||||
match = re.search(pattern, context, re.DOTALL)
|
||||
if match:
|
||||
text_info = match.group(1)
|
||||
logger.info(text_info)
|
||||
else:
|
||||
text_info = context
|
||||
logger.info("未找到匹配的内容")
|
||||
|
||||
return text_info
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class UIDesign(Action):
|
||||
"""Class representing the UI Design action."""
|
||||
|
||||
def __init__(self, name, context=None, llm=None):
|
||||
super().__init__(name, context, llm) # 需要调用LLM进一步丰富UI设计的prompt
|
||||
|
||||
@parse
|
||||
def parse_requirement(self, context: str):
|
||||
"""Parse UI Design draft from the context using regex."""
|
||||
pattern = r"## UI Design draft.*?\n(.*?)## Anything UNCLEAR"
|
||||
return context, pattern
|
||||
|
||||
@parse
|
||||
def parse_ui_elements(self, context: str):
|
||||
"""Parse Selected Elements from the context using regex."""
|
||||
pattern = r"## Selected Elements.*?\n(.*?)## HTML Layout"
|
||||
return context, pattern
|
||||
|
||||
@parse
|
||||
def parse_css_code(self, context: str):
|
||||
pattern = r"```css.*?\n(.*?)## Anything UNCLEAR"
|
||||
return context, pattern
|
||||
|
||||
@parse
|
||||
def parse_html_code(self, context: str):
|
||||
pattern = r"```html.*?\n(.*?)```"
|
||||
return context, pattern
|
||||
|
||||
async def draw_icons(self, context, *args, **kwargs):
|
||||
"""Draw icons using SDEngine."""
|
||||
engine = SDEngine()
|
||||
icon_prompts = self.parse_ui_elements(context)
|
||||
icons = icon_prompts.split("\n")
|
||||
icons = [s for s in icons if len(s.strip()) > 0]
|
||||
prompts_batch = []
|
||||
for icon_prompt in icons:
|
||||
# fixme: 添加icon lora
|
||||
prompt = engine.construct_payload(icon_prompt + ".<lora:WZ0710_AW81e-3_30e3b128d64T32_goon0.5>")
|
||||
prompts_batch.append(prompt)
|
||||
await engine.run_t2i(prompts_batch)
|
||||
logger.info("Finish icon design using StableDiffusion API")
|
||||
|
||||
async def _save(self, css_content, html_content):
|
||||
save_dir = WORKSPACE_ROOT / "resources" / "codes"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# Save CSS and HTML content to files
|
||||
css_file_path = save_dir / "ui_design.css"
|
||||
html_file_path = save_dir / "ui_design.html"
|
||||
|
||||
with open(css_file_path, "w") as css_file:
|
||||
css_file.write(css_content)
|
||||
with open(html_file_path, "w") as html_file:
|
||||
html_file.write(html_content)
|
||||
|
||||
async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput:
|
||||
"""Run the UI Design action."""
|
||||
# fixme: update prompt (根据需求细化prompt)
|
||||
context = requirements[-1].content
|
||||
ui_design_draft = self.parse_requirement(context=context)
|
||||
# todo: parse requirements str
|
||||
prompt = PROMPT_TEMPLATE.format(context=ui_design_draft, format_example=FORMAT_EXAMPLE)
|
||||
logger.info(prompt)
|
||||
ui_describe = await self._aask_v1(prompt, "ui_design", OUTPUT_MAPPING)
|
||||
logger.info(ui_describe.content)
|
||||
logger.info(ui_describe.instruct_content)
|
||||
css = self.parse_css_code(context=ui_describe.content)
|
||||
html = self.parse_html_code(context=ui_describe.content)
|
||||
await self._save(css_content=css, html_content=html)
|
||||
await self.draw_icons(ui_describe.content)
|
||||
return ui_describe
|
||||
|
||||
|
||||
class UI(Role):
|
||||
"""Class representing the UI Role."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name="Catherine",
|
||||
profile="UI Design",
|
||||
goal="Finish a workable and good User Interface design based on a product design",
|
||||
constraints="Give clear layout description and use standard icons to finish the design",
|
||||
skills=["SD"],
|
||||
):
|
||||
super().__init__(name, profile, goal, constraints)
|
||||
self.load_skills(skills)
|
||||
self._init_actions([UIDesign])
|
||||
self._watch([WritePRD])
|
||||
|
||||
@load_engine
|
||||
def load_sd_engine(self):
|
||||
"""Load the SDEngine."""
|
||||
file_name = ".tools.sd_engine"
|
||||
engine_name = "SDEngine"
|
||||
return file_name, engine_name
|
||||
|
||||
def load_skills(self, skills):
|
||||
"""Load skills for the UI Role."""
|
||||
# todo: 添加其他出图engine
|
||||
for skill in skills:
|
||||
if skill == "SD":
|
||||
self.sd_engine = self.load_sd_engine()
|
||||
logger.info(f"load skill engine {self.sd_engine}")
|
||||
4
tests/metagpt/serialize_deserialize/__init__.py
Normal file
4
tests/metagpt/serialize_deserialize/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 11:48 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
32
tests/metagpt/serialize_deserialize/test_action.py
Normal file
32
tests/metagpt/serialize_deserialize/test_action.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 11:48 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_action_serialize():
|
||||
action = Action()
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
assert "__module_class_name" not in ser_action_dict
|
||||
|
||||
action = Action(name="test")
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "test" in ser_action_dict["name"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = Action()
|
||||
serialized_data = action.model_dump()
|
||||
|
||||
new_action = Action(**serialized_data)
|
||||
|
||||
assert new_action.name == "Action"
|
||||
assert isinstance(new_action.llm, type(LLM()))
|
||||
assert len(await new_action._aask("who are you")) > 0
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/26/2023 2:04 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.roles.architect import Architect
|
||||
|
||||
|
||||
def test_architect_serialize():
|
||||
role = Architect()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect_deserialize():
|
||||
role = Architect()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
new_role = Architect(**ser_role_dict)
|
||||
# new_role = Architect.deserialize(ser_role_dict)
|
||||
assert new_role.name == "Bob"
|
||||
assert len(new_role.actions) == 1
|
||||
assert isinstance(new_role.actions[0], Action)
|
||||
await new_role.actions[0].run(with_messages="write a cli snake game")
|
||||
90
tests/metagpt/serialize_deserialize/test_environment.py
Normal file
90
tests/metagpt/serialize_deserialize/test_environment.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
ActionRaise,
|
||||
RoleC,
|
||||
serdeser_path,
|
||||
)
|
||||
|
||||
|
||||
def test_env_serialize():
|
||||
env = Environment()
|
||||
ser_env_dict = env.model_dump()
|
||||
assert "roles" in ser_env_dict
|
||||
assert len(ser_env_dict["roles"]) == 0
|
||||
|
||||
|
||||
def test_env_deserialize():
|
||||
env = Environment()
|
||||
env.publish_message(message=Message(content="test env serialize"))
|
||||
ser_env_dict = env.model_dump()
|
||||
new_env = Environment(**ser_env_dict)
|
||||
assert len(new_env.roles) == 0
|
||||
assert len(new_env.history) == 25
|
||||
|
||||
|
||||
def test_environment_serdeser():
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
message = Message(
|
||||
content="prd", instruct_content=ic_obj(**out_data), role="product manager", cause_by=any_to_str(UserRequirement)
|
||||
)
|
||||
|
||||
environment = Environment()
|
||||
role_c = RoleC()
|
||||
environment.add_role(role_c)
|
||||
environment.publish_message(message)
|
||||
|
||||
ser_data = environment.model_dump()
|
||||
assert ser_data["roles"]["Role C"]["name"] == "RoleC"
|
||||
|
||||
new_env: Environment = Environment(**ser_data)
|
||||
assert len(new_env.roles) == 1
|
||||
|
||||
assert list(new_env.roles.values())[0].states == list(environment.roles.values())[0].states
|
||||
assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
|
||||
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
|
||||
assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise
|
||||
|
||||
|
||||
def test_environment_serdeser_v2():
|
||||
environment = Environment()
|
||||
pm = ProjectManager()
|
||||
environment.add_role(pm)
|
||||
|
||||
ser_data = environment.model_dump()
|
||||
|
||||
new_env: Environment = Environment(**ser_data)
|
||||
role = new_env.get_role(pm.profile)
|
||||
assert isinstance(role, ProjectManager)
|
||||
assert isinstance(role.actions[0], WriteTasks)
|
||||
assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
|
||||
|
||||
|
||||
def test_environment_serdeser_save():
|
||||
environment = Environment()
|
||||
role_c = RoleC()
|
||||
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
environment.add_role(role_c)
|
||||
environment.serialize(stg_path)
|
||||
|
||||
new_env: Environment = Environment.deserialize(stg_path)
|
||||
assert len(new_env.roles) == 1
|
||||
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
|
||||
66
tests/metagpt/serialize_deserialize/test_memory.py
Normal file
66
tests/metagpt/serialize_deserialize/test_memory.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of memory
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path
|
||||
|
||||
|
||||
def test_memory_serdeser():
|
||||
msg1 = Message(role="Boss", content="write a snake game", cause_by=UserRequirement)
|
||||
|
||||
out_mapping = {"field2": (list[str], ...)}
|
||||
out_data = {"field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("system_design", out_mapping)
|
||||
msg2 = Message(
|
||||
role="Architect", instruct_content=ic_obj(**out_data), content="system design content", cause_by=WriteDesign
|
||||
)
|
||||
|
||||
memory = Memory()
|
||||
memory.add_batch([msg1, msg2])
|
||||
ser_data = memory.model_dump()
|
||||
|
||||
new_memory = Memory(**ser_data)
|
||||
assert new_memory.count() == 2
|
||||
new_msg2 = new_memory.get(2)[0]
|
||||
assert isinstance(new_msg2, BaseModel)
|
||||
assert isinstance(new_memory.storage[-1], BaseModel)
|
||||
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
|
||||
assert new_msg2.role == "Boss"
|
||||
|
||||
memory = Memory(storage=[msg1, msg2], index={msg1.cause_by: [msg1], msg2.cause_by: [msg2]})
|
||||
assert memory.count() == 2
|
||||
|
||||
|
||||
def test_memory_serdeser_save():
|
||||
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)
|
||||
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("system_design", out_mapping)
|
||||
msg2 = Message(
|
||||
role="Architect", instruct_content=ic_obj(**out_data), content="system design content", cause_by=WriteDesign
|
||||
)
|
||||
|
||||
memory = Memory()
|
||||
memory.add_batch([msg1, msg2])
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
memory.serialize(stg_path)
|
||||
assert stg_path.joinpath("memory.json").exists()
|
||||
|
||||
new_memory = Memory.deserialize(stg_path)
|
||||
assert new_memory.count() == 2
|
||||
new_msg2 = new_memory.get(1)[0]
|
||||
assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"]
|
||||
assert new_msg2.cause_by == any_to_str(WriteDesign)
|
||||
assert len(new_memory.index) == 2
|
||||
|
||||
stg_path.joinpath("memory.json").unlink()
|
||||
58
tests/metagpt/serialize_deserialize/test_polymorphic.py
Normal file
58
tests/metagpt/serialize_deserialize/test_polymorphic.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of polymorphic conditions
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, SerializeAsAny
|
||||
|
||||
from metagpt.actions import Action
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOKV2,
|
||||
ActionPass,
|
||||
)
|
||||
|
||||
|
||||
class ActionSubClasses(BaseModel):
|
||||
actions: list[SerializeAsAny[Action]] = []
|
||||
|
||||
|
||||
class ActionSubClassesNoSAA(BaseModel):
|
||||
"""without SerializeAsAny"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
actions: list[Action] = []
|
||||
|
||||
|
||||
def test_serialize_as_any():
|
||||
"""test subclasses of action with different fields in ser&deser"""
|
||||
# ActionOKV2 with a extra field `extra_field`
|
||||
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
assert action_subcls_dict["actions"][0]["extra_field"] == ActionOKV2().extra_field
|
||||
|
||||
|
||||
def test_no_serialize_as_any():
|
||||
# ActionOKV2 with a extra field `extra_field`
|
||||
action_subcls = ActionSubClassesNoSAA(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
# without `SerializeAsAny`, it will serialize as Action
|
||||
assert "extra_field" not in action_subcls_dict["actions"][0]
|
||||
|
||||
|
||||
def test_polymorphic():
|
||||
_ = ActionOKV2(
|
||||
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
|
||||
)
|
||||
|
||||
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
|
||||
assert "__module_class_name" in action_subcls_dict["actions"][0]
|
||||
|
||||
new_action_subcls = ActionSubClasses(**action_subcls_dict)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
||||
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.prepare_interview import PrepareInterview
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = PrepareInterview()
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "PrepareInterview"
|
||||
|
||||
new_action = PrepareInterview(**serialized_data)
|
||||
|
||||
assert new_action.name == "PrepareInterview"
|
||||
assert type(await new_action.run("python developer")) == ActionNode
|
||||
21
tests/metagpt/serialize_deserialize/test_product_manager.py
Normal file
21
tests/metagpt/serialize_deserialize/test_product_manager.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/26/2023 2:07 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_product_manager_deserialize(new_filename):
|
||||
role = ProductManager()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
new_role = ProductManager(**ser_role_dict)
|
||||
|
||||
assert new_role.name == "Alice"
|
||||
assert len(new_role.actions) == 2
|
||||
assert isinstance(new_role.actions[0], Action)
|
||||
await new_role.actions[0].run([Message(content="write a cli snake game")])
|
||||
30
tests/metagpt/serialize_deserialize/test_project_manager.py
Normal file
30
tests/metagpt/serialize_deserialize/test_project_manager.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/26/2023 2:06 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
|
||||
|
||||
def test_project_manager_serialize():
|
||||
role = ProjectManager()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_manager_deserialize():
|
||||
role = ProjectManager()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
|
||||
new_role = ProjectManager(**ser_role_dict)
|
||||
assert new_role.name == "Eve"
|
||||
assert len(new_role.actions) == 1
|
||||
assert isinstance(new_role.actions[0], Action)
|
||||
assert isinstance(new_role.actions[0], WriteTasks)
|
||||
# await new_role.actions[0].run(context="write a cli snake game")
|
||||
22
tests/metagpt/serialize_deserialize/test_reasearcher.py
Normal file
22
tests/metagpt/serialize_deserialize/test_reasearcher.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import CollectLinks
|
||||
from metagpt.roles.researcher import Researcher
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tutorial_assistant_deserialize():
|
||||
role = Researcher()
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
assert "language" in ser_role_dict
|
||||
|
||||
new_role = Researcher(**ser_role_dict)
|
||||
assert new_role.language == "en-us"
|
||||
assert len(new_role.actions) == 3
|
||||
assert isinstance(new_role.actions[0], CollectLinks)
|
||||
|
||||
# todo: 需要测试不同的action失败下,记忆是否正常保存
|
||||
120
tests/metagpt/serialize_deserialize/test_role.py
Normal file
120
tests/metagpt/serialize_deserialize/test_role.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/23/2023 4:49 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, SerializeAsAny
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import format_trackback_info
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
RoleA,
|
||||
RoleB,
|
||||
RoleC,
|
||||
RoleD,
|
||||
serdeser_path,
|
||||
)
|
||||
|
||||
|
||||
def test_roles():
|
||||
role_a = RoleA()
|
||||
assert len(role_a.rc.watch) == 1
|
||||
role_b = RoleB()
|
||||
assert len(role_a.rc.watch) == 1
|
||||
assert len(role_b.rc.watch) == 1
|
||||
|
||||
role_d = RoleD(actions=[ActionOK()])
|
||||
assert len(role_d.actions) == 1
|
||||
|
||||
|
||||
def test_role_subclasses():
|
||||
"""test subclasses of role with same fields in ser&deser"""
|
||||
|
||||
class RoleSubClasses(BaseModel):
|
||||
roles: list[SerializeAsAny[Role]] = []
|
||||
|
||||
role_subcls = RoleSubClasses(roles=[RoleA(), RoleB()])
|
||||
role_subcls_dict = role_subcls.model_dump()
|
||||
|
||||
new_role_subcls = RoleSubClasses(**role_subcls_dict)
|
||||
assert isinstance(new_role_subcls.roles[0], RoleA)
|
||||
assert isinstance(new_role_subcls.roles[1], RoleB)
|
||||
|
||||
|
||||
def test_role_serialize():
|
||||
role = Role()
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
def test_engineer_serialize():
|
||||
role = Engineer()
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer_deserialize():
|
||||
role = Engineer(use_code_review=True)
|
||||
ser_role_dict = role.model_dump()
|
||||
|
||||
new_role = Engineer(**ser_role_dict)
|
||||
assert new_role.name == "Alex"
|
||||
assert new_role.use_code_review is True
|
||||
assert len(new_role.actions) == 1
|
||||
assert isinstance(new_role.actions[0], WriteCode)
|
||||
# await new_role.actions[0].run(context="write a cli snake game", filename="test_code")
|
||||
|
||||
|
||||
def test_role_serdeser_save():
|
||||
stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles")
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
pm = ProductManager()
|
||||
role_tag = f"{pm.__class__.__name__}_{pm.name}"
|
||||
stg_path = stg_path_prefix.joinpath(role_tag)
|
||||
pm.serialize(stg_path)
|
||||
|
||||
new_pm = Role.deserialize(stg_path)
|
||||
assert new_pm.name == pm.name
|
||||
assert len(new_pm.get_memories(1)) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_serdeser_interrupt():
|
||||
role_c = RoleC()
|
||||
shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = SERDESER_PATH.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}")
|
||||
try:
|
||||
await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
except Exception:
|
||||
logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}")
|
||||
role_c.serialize(stg_path)
|
||||
|
||||
assert role_c.rc.memory.count() == 1
|
||||
|
||||
new_role_a: Role = Role.deserialize(stg_path)
|
||||
assert new_role_a.rc.state == 1
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue