mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-14 15:25:17 +02:00
Merge branch 'dev' into update-unit-test
This commit is contained in:
commit
bf0f6bd272
148 changed files with 6195 additions and 691 deletions
|
|
@ -13,17 +13,17 @@ from unittest.mock import Mock
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self):
|
||||
self._llm_ui = None
|
||||
self._llm_api = GPTAPI()
|
||||
self._llm_api = LLM(provider=CONFIG.get_default_llm_provider_enum())
|
||||
|
||||
@property
|
||||
def llm_api(self):
|
||||
|
|
@ -96,3 +96,8 @@ def setup_and_teardown_git_repo(request):
|
|||
|
||||
# Register the function for destroying the environment.
|
||||
request.addfinalizer(fin)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def init_config():
|
||||
Config()
|
||||
|
|
|
|||
143
tests/metagpt/actions/mock_json.py
Normal file
143
tests/metagpt/actions/mock_json.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/24 20:32
|
||||
@Author : alexanderwu
|
||||
@File : mock_json.py
|
||||
"""
|
||||
|
||||
PRD = {
|
||||
"Language": "zh_cn",
|
||||
"Programming Language": "Python",
|
||||
"Original Requirements": "写一个简单的cli贪吃蛇",
|
||||
"Project Name": "cli_snake",
|
||||
"Product Goals": ["创建一个简单易用的贪吃蛇游戏", "提供良好的用户体验", "支持不同难度级别"],
|
||||
"User Stories": [
|
||||
"作为玩家,我希望能够选择不同的难度级别",
|
||||
"作为玩家,我希望在每局游戏结束后能够看到我的得分",
|
||||
"作为玩家,我希望在输掉游戏后能够重新开始",
|
||||
"作为玩家,我希望看到简洁美观的界面",
|
||||
"作为玩家,我希望能够在手机上玩游戏",
|
||||
],
|
||||
"Competitive Analysis": ["贪吃蛇游戏A:界面简单,缺乏响应式特性", "贪吃蛇游戏B:美观且响应式的界面,显示最高得分", "贪吃蛇游戏C:响应式界面,显示最高得分,但有很多广告"],
|
||||
"Competitive Quadrant Chart": 'quadrantChart\n title "Reach and engagement of campaigns"\n x-axis "Low Reach" --> "High Reach"\n y-axis "Low Engagement" --> "High Engagement"\n quadrant-1 "We should expand"\n quadrant-2 "Need to promote"\n quadrant-3 "Re-evaluate"\n quadrant-4 "May be improved"\n "Game A": [0.3, 0.6]\n "Game B": [0.45, 0.23]\n "Game C": [0.57, 0.69]\n "Game D": [0.78, 0.34]\n "Game E": [0.40, 0.34]\n "Game F": [0.35, 0.78]\n "Our Target Product": [0.5, 0.6]',
|
||||
"Requirement Analysis": "",
|
||||
"Requirement Pool": [["P0", "主要代码..."], ["P0", "游戏算法..."]],
|
||||
"UI Design draft": "基本功能描述,简单的风格和布局。",
|
||||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
|
||||
DESIGN = {
|
||||
"Implementation approach": "我们将使用Python编程语言,并选择合适的开源框架来实现贪吃蛇游戏。我们将分析需求中的难点,并选择合适的开源框架来简化开发流程。",
|
||||
"File list": ["main.py", "game.py"],
|
||||
"Data structures and interfaces": "\nclassDiagram\n class Game {\n -int width\n -int height\n -int score\n -int speed\n -List<Point> snake\n -Point food\n +__init__(width: int, height: int, speed: int)\n +start_game()\n +change_direction(direction: str)\n +game_over()\n +update_snake()\n +update_food()\n +check_collision()\n }\n class Point {\n -int x\n -int y\n +__init__(x: int, y: int)\n }\n Game --> Point\n",
|
||||
"Program call flow": "\nsequenceDiagram\n participant M as Main\n participant G as Game\n M->>G: start_game()\n M->>G: change_direction(direction)\n G->>G: update_snake()\n G->>G: update_food()\n G->>G: check_collision()\n G-->>G: game_over()\n",
|
||||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
|
||||
TASKS = {
|
||||
"Required Python packages": ["pygame==2.0.1"],
|
||||
"Required Other language third-party packages": ["No third-party dependencies required"],
|
||||
"Logic Analysis": [
|
||||
["game.py", "Contains Game class and related functions for game logic"],
|
||||
["main.py", "Contains the main function, imports Game class from game.py"],
|
||||
],
|
||||
"Task list": ["game.py", "main.py"],
|
||||
"Full API spec": "",
|
||||
"Shared Knowledge": "'game.py' contains functions shared across the project.",
|
||||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
|
||||
FILE_GAME = """## game.py
|
||||
|
||||
import pygame
|
||||
import random
|
||||
|
||||
class Point:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
class Game:
|
||||
def __init__(self, width: int, height: int, speed: int):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.score = 0
|
||||
self.speed = speed
|
||||
self.snake = [Point(width // 2, height // 2)]
|
||||
self.food = self._create_food()
|
||||
|
||||
def start_game(self):
|
||||
pygame.init()
|
||||
self._display = pygame.display.set_mode((self.width, self.height))
|
||||
pygame.display.set_caption('Snake Game')
|
||||
self._clock = pygame.time.Clock()
|
||||
self._running = True
|
||||
|
||||
while self._running:
|
||||
self._handle_events()
|
||||
self._update_snake()
|
||||
self._update_food()
|
||||
self._check_collision()
|
||||
self._draw_screen()
|
||||
self._clock.tick(self.speed)
|
||||
|
||||
def change_direction(self, direction: str):
|
||||
# Update the direction of the snake based on user input
|
||||
pass
|
||||
|
||||
def game_over(self):
|
||||
# Display game over message and handle game over logic
|
||||
pass
|
||||
|
||||
def _create_food(self) -> Point:
|
||||
# Create and return a new food Point
|
||||
return Point(random.randint(0, self.width - 1), random.randint(0, self.height - 1))
|
||||
|
||||
def _handle_events(self):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
self._running = False
|
||||
|
||||
def _update_snake(self):
|
||||
# Update the position of the snake based on its direction
|
||||
pass
|
||||
|
||||
def _update_food(self):
|
||||
# Update the position of the food if the snake eats it
|
||||
pass
|
||||
|
||||
def _check_collision(self):
|
||||
# Check for collision between the snake and the walls or itself
|
||||
pass
|
||||
|
||||
def _draw_screen(self):
|
||||
self._display.fill((0, 0, 0)) # Clear the screen
|
||||
# Draw the snake and food on the screen
|
||||
pygame.display.update()
|
||||
|
||||
if __name__ == "__main__":
|
||||
game = Game(800, 600, 15)
|
||||
game.start_game()
|
||||
"""
|
||||
|
||||
FILE_GAME_CR_1 = """## Code Review: game.py
|
||||
1. Yes, the code is implemented as per the requirements. It initializes the game with the specified width, height, and speed, and starts the game loop.
|
||||
2. No, the logic for handling events and updating the snake, food, and collision is not implemented. To correct this, we need to implement the logic for handling events, updating the snake and food positions, and checking for collisions.
|
||||
3. Yes, the existing code follows the "Data structures and interfaces" by defining the Game and Point classes with the specified attributes and methods.
|
||||
4. No, several functions such as change_direction, game_over, _update_snake, _update_food, and _check_collision are not implemented. These functions need to be implemented to complete the game logic.
|
||||
5. Yes, all necessary pre-dependencies have been imported. The required pygame package is imported at the beginning of the file.
|
||||
6. No, methods from other files are not being reused as there are no other files being imported or referenced in the current code.
|
||||
|
||||
## Actions
|
||||
1. Implement the logic for handling events, updating the snake and food positions, and checking for collisions within the Game class.
|
||||
2. Implement the change_direction and game_over methods to handle user input and game over logic.
|
||||
3. Implement the _update_snake method to update the position of the snake based on its direction.
|
||||
4. Implement the _update_food method to update the position of the food if the snake eats it.
|
||||
5. Implement the _check_collision method to check for collision between the snake and the walls or itself.
|
||||
|
||||
## Code Review Result
|
||||
LBTM"""
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
@Time : 2023/5/18 23:51
|
||||
@Author : alexanderwu
|
||||
@File : mock.py
|
||||
@File : mock_markdown.py
|
||||
"""
|
||||
|
||||
PRD_SAMPLE = """## Original Requirements
|
||||
|
|
@ -5,9 +5,16 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_action.py
|
||||
"""
|
||||
from metagpt.actions import Action, WritePRD, WriteTest
|
||||
from metagpt.actions import Action, ActionType, WritePRD, WriteTest
|
||||
|
||||
|
||||
def test_action_repr():
|
||||
actions = [Action(), WriteTest(), WritePRD()]
|
||||
assert "WriteTest" in str(actions)
|
||||
|
||||
|
||||
def test_action_type():
|
||||
assert ActionType.WRITE_PRD.value == WritePRD
|
||||
assert ActionType.WRITE_TEST.value == WriteTest
|
||||
assert ActionType.WRITE_PRD.name == "WRITE_PRD"
|
||||
assert ActionType.WRITE_TEST.name == "WRITE_TEST"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_action_node.py
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
|
|
@ -29,7 +31,7 @@ async def test_debate_two_roles():
|
|||
team = Team(investment=10.0, env=env, roles=[biden, trump])
|
||||
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
|
||||
assert "BidenSay" in history
|
||||
assert "Biden" in history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -39,7 +41,7 @@ async def test_debate_one_role_in_env():
|
|||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden])
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
|
||||
assert "Debate" in history
|
||||
assert "Biden" in history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -86,3 +88,47 @@ async def test_action_node_two_layer():
|
|||
assert node_b in root.children.values()
|
||||
json_template = root.compile(context="123", schema="json", mode="auto")
|
||||
assert "i-a" in json_template
|
||||
|
||||
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
|
||||
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."],
|
||||
],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
"Required Other language third-party packages": (str, ...),
|
||||
"Full API spec": (str, ...),
|
||||
"Logic Analysis": (List[Tuple[str, str]], ...),
|
||||
"Task list": (List[str], ...),
|
||||
"Shared Knowledge": (str, ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
|
||||
|
||||
def test_create_model_class():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t1 = t(**t_dict)
|
||||
value = t1.dict()["Task list"]
|
||||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
|
|||
|
|
@ -1,53 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
"""
|
||||
@Time : 2023/7/11 10:49
|
||||
@Author : chengmaoyu
|
||||
@File : test_action_output
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
|
||||
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."],
|
||||
],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
"Required Other language third-party packages": (str, ...),
|
||||
"Full API spec": (str, ...),
|
||||
"Logic Analysis": (List[Tuple[str, str]], ...),
|
||||
"Task list": (List[str], ...),
|
||||
"Shared Knowledge": (str, ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
|
||||
|
||||
def test_create_model_class():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t1 = t(**t_dict)
|
||||
value = t1.dict()["Task list"]
|
||||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
@ -1,6 +1,13 @@
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.clone_function import CloneFunction, run_function_code
|
||||
from metagpt.actions.clone_function import (
|
||||
CloneFunction,
|
||||
run_function_code,
|
||||
run_function_script,
|
||||
)
|
||||
|
||||
source_code = """
|
||||
import pandas as pd
|
||||
|
|
@ -55,3 +62,40 @@ async def test_clone_function():
|
|||
assert not msg
|
||||
expected_df = get_expected_res()
|
||||
assert df.equals(expected_df)
|
||||
|
||||
|
||||
def test_run_function_script():
|
||||
# 创建一个临时文件并写入脚本内容
|
||||
script_content = """def valid_function(arg1, arg2):\n return arg1 + arg2\n"""
|
||||
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as temp_file:
|
||||
temp_file.write(script_content)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
invalid_script_content = """def valid_function(arg1, arg2)\n return arg1 + arg2\n"""
|
||||
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as error_temp_file:
|
||||
error_temp_file.write(invalid_script_content)
|
||||
error_temp_file_path = error_temp_file.name
|
||||
|
||||
try:
|
||||
# 正常情况下运行脚本
|
||||
result, _ = run_function_script(temp_file_path, "valid_function", 1, arg2=2)
|
||||
assert result == 3
|
||||
|
||||
# 不存在的脚本路径
|
||||
with pytest.raises(FileNotFoundError):
|
||||
run_function_script("nonexistent/path/script.py", "valid_function", 1, arg2=2)
|
||||
|
||||
# 无效的脚本内容
|
||||
result, traceback = run_function_script(error_temp_file_path, "invalid_function", 1, arg2=2)
|
||||
assert not result
|
||||
assert "SyntaxError" in traceback
|
||||
|
||||
# 函数调用失败的情况
|
||||
result, traceback = run_function_script(temp_file_path, "function_that_raises_exception", 1, arg2=2)
|
||||
assert not result
|
||||
assert "KeyError" in traceback
|
||||
|
||||
finally:
|
||||
# 删除临时文件
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.const import PRDS_FILE_REPO
|
|||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock import PRD_SAMPLE
|
||||
from tests.metagpt.actions.mock_markdown import PRD_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -22,9 +22,9 @@ async def test_design_api():
|
|||
for prd in inputs:
|
||||
await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
|
||||
|
||||
design_api = WriteDesign("design_api")
|
||||
design_api = WriteDesign()
|
||||
|
||||
result = await design_api.run([Message(content=prd, instruct_content=None)])
|
||||
result = await design_api.run(Message(content=prd, instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ API列表:
|
|||
"""
|
||||
_ = "API设计看起来非常合理,满足了PRD中的所有需求。"
|
||||
|
||||
design_api_review = DesignReview("design_api_review")
|
||||
design_api_review = DesignReview()
|
||||
|
||||
result = await design_api_review.run(prd, api_design)
|
||||
|
||||
|
|
|
|||
17
tests/metagpt/actions/test_fix_bug.py
Normal file
17
tests/metagpt/actions/test_fix_bug.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/25 22:38
|
||||
@Author : alexanderwu
|
||||
@File : test_fix_bug.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_bug():
|
||||
fix_bug = FixBug()
|
||||
assert fix_bug.name == "FixBug"
|
||||
|
|
@ -6,10 +6,26 @@
|
|||
@File : test_project_management.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
class TestCreateProjectPlan:
|
||||
pass
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock_json import DESIGN, PRD
|
||||
|
||||
|
||||
class TestAssignTasks:
|
||||
pass
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api():
|
||||
await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
|
||||
await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
logger.info(CONFIG.git_repo)
|
||||
|
||||
action = WriteTasks()
|
||||
|
||||
result = await action.run(Message(content="", instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
|
|
|||
24
tests/metagpt/actions/test_rebuild_class_view.py
Normal file
24
tests/metagpt/actions/test_rebuild_class_view.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/20
|
||||
@Author : mashenquan
|
||||
@File : test_rebuild_class_view.py
|
||||
@Desc : Unit tests for rebuild_class_view.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.rebuild_class_view import RebuildClassView
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild():
|
||||
action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM())
|
||||
await action.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
65
tests/metagpt/actions/test_skill_action.py
Normal file
65
tests/metagpt/actions/test_skill_action.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/19
|
||||
@Author : mashenquan
|
||||
@File : test_skill_action.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
|
||||
from metagpt.learn.skill_loader import Example, Parameter, Returns, Skill
|
||||
|
||||
|
||||
class TestSkillAction:
|
||||
skill = Skill(
|
||||
name="text_to_image",
|
||||
description="Create a drawing based on the text.",
|
||||
id="text_to_image.text_to_image",
|
||||
x_prerequisite={
|
||||
"configurations": {
|
||||
"OPENAI_API_KEY": {
|
||||
"type": "string",
|
||||
"description": "OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`",
|
||||
},
|
||||
"METAGPT_TEXT_TO_IMAGE_MODEL_URL": {"type": "string", "description": "Model url."},
|
||||
},
|
||||
"required": {"oneOf": ["OPENAI_API_KEY", "METAGPT_TEXT_TO_IMAGE_MODEL_URL"]},
|
||||
},
|
||||
parameters={
|
||||
"text": Parameter(type="string", description="The text used for image conversion."),
|
||||
"size_type": Parameter(type="string", description="size type"),
|
||||
},
|
||||
examples=[
|
||||
Example(ask="Draw a girl", answer='text_to_image(text="Draw a girl", size_type="512x512")'),
|
||||
Example(ask="Draw an apple", answer='text_to_image(text="Draw an apple", size_type="512x512")'),
|
||||
],
|
||||
returns=Returns(type="string", format="base64"),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parser(self):
|
||||
args = ArgumentsParingAction.parse_arguments(
|
||||
skill_name="text_to_image", txt='`text_to_image(text="Draw an apple", size_type="512x512")`'
|
||||
)
|
||||
assert args.get("text") == "Draw an apple"
|
||||
assert args.get("size_type") == "512x512"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parser_action(self):
|
||||
parser_action = ArgumentsParingAction(skill=self.skill, ask="Draw an apple")
|
||||
rsp = await parser_action.run()
|
||||
assert rsp
|
||||
assert parser_action.args
|
||||
assert parser_action.args.get("text") == "Draw an apple"
|
||||
assert parser_action.args.get("size_type") == "512x512"
|
||||
|
||||
action = SkillAction(skill=self.skill, args=parser_action.args)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert "image/png;base64," in rsp.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -9,10 +9,10 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
58
tests/metagpt/actions/test_write_teaching_plan.py
Normal file
58
tests/metagpt/actions/test_write_teaching_plan.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/28 17:25
|
||||
@Author : mashenquan
|
||||
@File : test_write_teaching_plan.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions.write_teaching_plan import WriteTeachingPlanPart
|
||||
from metagpt.config import Config
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class MockWriteTeachingPlanPart(WriteTeachingPlanPart):
|
||||
def __init__(self, options, name: str = "", context=None, llm: LLM = None, topic="", language="Chinese"):
|
||||
super().__init__(options, name, context, llm, topic, language)
|
||||
|
||||
async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
return f"{WriteTeachingPlanPart.DATA_BEGIN_TAG}\nprompt\n{WriteTeachingPlanPart.DATA_END_TAG}"
|
||||
|
||||
|
||||
async def mock_write_teaching_plan_part():
|
||||
class Inputs(BaseModel):
|
||||
input: str
|
||||
name: str
|
||||
topic: str
|
||||
language: str
|
||||
|
||||
inputs = [
|
||||
{"input": "AABBCC", "name": "A", "topic": WriteTeachingPlanPart.COURSE_TITLE, "language": "C"},
|
||||
{"input": "DDEEFFF", "name": "A1", "topic": "B1", "language": "C1"},
|
||||
]
|
||||
|
||||
for i in inputs:
|
||||
seed = Inputs(**i)
|
||||
options = Config().runtime_options
|
||||
act = MockWriteTeachingPlanPart(options=options, name=seed.name, topic=seed.topic, language=seed.language)
|
||||
await act.run([Message(content="")])
|
||||
assert act.topic == seed.topic
|
||||
assert str(act) == seed.topic
|
||||
assert act.name == seed.name
|
||||
assert act.rsp == "# prompt" if seed.topic == WriteTeachingPlanPart.COURSE_TITLE else "prompt"
|
||||
|
||||
|
||||
def test_suite():
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(mock_write_teaching_plan_part())
|
||||
loop.run_until_complete(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_suite()
|
||||
|
|
@ -51,3 +51,7 @@ async def test_write_code_invalid_code(mocker):
|
|||
|
||||
# Assert that the returned code is the same as the invalid code string
|
||||
assert code == "Invalid Code String"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -5,73 +5,28 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_faiss_store.py
|
||||
"""
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
from metagpt.document_store import FaissStore
|
||||
from metagpt.roles import CustomerService, Sales
|
||||
|
||||
DESC = """## 原则(所有事情都不可绕过原则)
|
||||
1. 你是一位平台的人工客服,话语精炼,一次只说一句话,会参考规则与FAQ进行回复。在与顾客交谈中,绝不允许暴露规则与相关字样
|
||||
2. 在遇到问题时,先尝试仅安抚顾客情绪,如果顾客情绪十分不好,再考虑赔偿。如果赔偿的过多,你会被开除
|
||||
3. 绝不要向顾客做虚假承诺,不要提及其他人的信息
|
||||
|
||||
## 技能(在回答尾部,加入`skill(args)`就可以使用技能)
|
||||
1. 查询订单:问顾客手机号是获得订单的唯一方式,获得手机号后,使用`find_order(手机号)`来获得订单
|
||||
2. 退款:输出关键词 `refund(手机号)`,系统会自动退款
|
||||
3. 开箱:需要手机号、确认顾客在柜前,如果需要开箱,输出指令 `open_box(手机号)`,系统会自动开箱
|
||||
|
||||
### 使用技能例子
|
||||
user: 你好收不到取餐码
|
||||
小爽人工: 您好,请提供一下手机号
|
||||
user: 14750187158
|
||||
小爽人工: 好的,为您查询一下订单。您已经在柜前了吗?`find_order(14750187158)`
|
||||
user: 是的
|
||||
小爽人工: 您看下开了没有?`open_box(14750187158)`
|
||||
user: 开了,谢谢
|
||||
小爽人工: 好的,还有什么可以帮到您吗?
|
||||
user: 没有了
|
||||
小爽人工: 祝您生活愉快
|
||||
"""
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_faiss_store_search():
|
||||
store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
|
||||
store.add(["油皮洗面奶"])
|
||||
role = Sales(store=store)
|
||||
|
||||
queries = ["油皮洗面奶", "介绍下欧莱雅的"]
|
||||
for query in queries:
|
||||
rsp = await role.run(query)
|
||||
assert rsp
|
||||
|
||||
|
||||
def customer_service():
|
||||
store = FaissStore(DATA_PATH / "st/faq.xlsx", content_col="Question", meta_col="Answer")
|
||||
store.search = functools.partial(store.search, expand_cols=True)
|
||||
role = CustomerService(profile="小爽人工", desc=DESC, store=store)
|
||||
return role
|
||||
async def test_search_json():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
logger.info(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_faiss_store_customer_service():
|
||||
allq = [
|
||||
# ["我的餐怎么两小时都没到", "退货吧"],
|
||||
[
|
||||
"你好收不到取餐码,麻烦帮我开箱",
|
||||
"14750187158",
|
||||
]
|
||||
]
|
||||
role = customer_service()
|
||||
for queries in allq:
|
||||
for query in queries:
|
||||
rsp = await role.run(query)
|
||||
assert rsp
|
||||
|
||||
|
||||
def test_faiss_store_no_file():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
FaissStore(DATA_PATH / "wtf.json")
|
||||
async def test_search_xlsx():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
logger.info(result)
|
||||
|
|
|
|||
0
tests/metagpt/learn/__init__.py
Normal file
0
tests/metagpt/learn/__init__.py
Normal file
27
tests/metagpt/learn/test_google_search.py
Normal file
27
tests/metagpt/learn/test_google_search.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.learn.google_search import google_search
|
||||
|
||||
|
||||
async def mock_google_search():
|
||||
class Input(BaseModel):
|
||||
input: str
|
||||
|
||||
inputs = [{"input": "ai agent"}]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
result = await google_search(seed.input)
|
||||
assert result != ""
|
||||
|
||||
|
||||
def test_suite():
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(mock_google_search())
|
||||
loop.run_until_complete(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_suite()
|
||||
43
tests/metagpt/learn/test_skill_loader.py
Normal file
43
tests/metagpt/learn/test_skill_loader.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/19
|
||||
@Author : mashenquan
|
||||
@File : test_skill_loader.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.skill_loader import SkillsDeclaration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suite():
|
||||
CONFIG.agent_skills = [
|
||||
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
|
||||
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
|
||||
]
|
||||
loader = await SkillsDeclaration.load()
|
||||
skills = loader.get_skill_list()
|
||||
assert skills
|
||||
assert len(skills) >= 3
|
||||
for desc, name in skills.items():
|
||||
assert desc
|
||||
assert name
|
||||
|
||||
entity = loader.entities.get("Assistant")
|
||||
assert entity
|
||||
assert entity.skills
|
||||
for sk in entity.skills:
|
||||
assert sk
|
||||
assert sk.arguments
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
38
tests/metagpt/learn/test_text_to_embedding.py
Normal file
38
tests/metagpt/learn/test_text_to_embedding.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : test_text_to_embedding.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.learn.text_to_embedding import text_to_embedding
|
||||
from metagpt.tools.openai_text_to_embedding import ResultEmbedding
|
||||
|
||||
|
||||
async def mock_text_to_embedding():
|
||||
class Input(BaseModel):
|
||||
input: str
|
||||
|
||||
inputs = [{"input": "Panda emoji"}]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
data = await text_to_embedding(seed.input)
|
||||
v = ResultEmbedding(**data)
|
||||
assert len(v.data) > 0
|
||||
|
||||
|
||||
def test_suite():
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(mock_text_to_embedding())
|
||||
loop.run_until_complete(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_suite()
|
||||
42
tests/metagpt/learn/test_text_to_image.py
Normal file
42
tests/metagpt/learn/test_text_to_image.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : test_text_to_image.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.learn.text_to_image import text_to_image
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test():
|
||||
class Input(BaseModel):
|
||||
input: str
|
||||
size_type: str
|
||||
|
||||
inputs = [{"input": "Panda emoji", "size_type": "512x512"}]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
base64_data = await text_to_image(seed.input)
|
||||
assert base64_data != ""
|
||||
print(f"{seed.input} -> {base64_data}")
|
||||
flags = ";base64,"
|
||||
assert flags in base64_data
|
||||
ix = base64_data.find(flags) + len(flags)
|
||||
declaration = base64_data[0:ix]
|
||||
assert declaration
|
||||
data = base64_data[ix:]
|
||||
assert data
|
||||
assert base64.b64decode(data, validate=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
45
tests/metagpt/learn/test_text_to_speech.py
Normal file
45
tests/metagpt/learn/test_text_to_speech.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : test_text_to_speech.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.learn.text_to_speech import text_to_speech
|
||||
|
||||
|
||||
async def mock_text_to_speech():
|
||||
class Input(BaseModel):
|
||||
input: str
|
||||
|
||||
inputs = [{"input": "Panda emoji"}]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
base64_data = await text_to_speech(seed.input)
|
||||
assert base64_data != ""
|
||||
print(f"{seed.input} -> {base64_data}")
|
||||
flags = ";base64,"
|
||||
assert flags in base64_data
|
||||
ix = base64_data.find(flags) + len(flags)
|
||||
declaration = base64_data[0:ix]
|
||||
assert declaration
|
||||
data = base64_data[ix:]
|
||||
assert data
|
||||
assert base64.b64decode(data, validate=True)
|
||||
|
||||
|
||||
def test_suite():
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(mock_text_to_speech())
|
||||
loop.run_until_complete(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_suite()
|
||||
51
tests/metagpt/memory/test_brain_memory.py
Normal file
51
tests/metagpt/memory/test_brain_memory.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/27
|
||||
@Author : mashenquan
|
||||
@File : test_brain_memory.py
|
||||
"""
|
||||
# import json
|
||||
# from typing import List
|
||||
#
|
||||
# import pydantic
|
||||
#
|
||||
# from metagpt.memory.brain_memory import BrainMemory
|
||||
# from metagpt.schema import Message
|
||||
#
|
||||
#
|
||||
# def test_json():
|
||||
# class Input(pydantic.BaseModel):
|
||||
# history: List[str]
|
||||
# solution: List[str]
|
||||
# knowledge: List[str]
|
||||
# stack: List[str]
|
||||
#
|
||||
# inputs = [{"history": ["a", "b"], "solution": ["c"], "knowledge": ["d", "e"], "stack": ["f"]}]
|
||||
#
|
||||
# for i in inputs:
|
||||
# v = Input(**i)
|
||||
# bm = BrainMemory()
|
||||
# for h in v.history:
|
||||
# msg = Message(content=h)
|
||||
# bm.history.append(msg.dict())
|
||||
# for h in v.solution:
|
||||
# msg = Message(content=h)
|
||||
# bm.solution.append(msg.dict())
|
||||
# for h in v.knowledge:
|
||||
# msg = Message(content=h)
|
||||
# bm.knowledge.append(msg.dict())
|
||||
# for h in v.stack:
|
||||
# msg = Message(content=h)
|
||||
# bm.stack.append(msg.dict())
|
||||
# s = bm.json()
|
||||
# m = json.loads(s)
|
||||
# bm = BrainMemory(**m)
|
||||
# assert bm
|
||||
# for v in bm.history:
|
||||
# msg = Message(**v)
|
||||
# assert msg
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# test_json()
|
||||
|
|
@ -2,22 +2,25 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.memory import LongTermMemory
|
||||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_ltm_search():
|
||||
assert hasattr(CONFIG, "long_term_memory") is True
|
||||
openai_api_key = CONFIG.openai_api_key
|
||||
assert len(openai_api_key) > 20
|
||||
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
|
||||
assert len(CONFIG.openai_api_key) > 20
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
rc = RoleContext(watch=[UserRequirement])
|
||||
rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"})
|
||||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
|
|
@ -28,6 +31,7 @@ def test_ltm_search():
|
|||
ltm.add(message)
|
||||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
|
|
|||
57
tests/metagpt/memory/test_memory.py
Normal file
57
tests/metagpt/memory/test_memory.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Memory
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_memory():
|
||||
memory = Memory()
|
||||
|
||||
message1 = Message(content="test message1", role="user1")
|
||||
message2 = Message(content="test message2", role="user2")
|
||||
message3 = Message(content="test message3", role="user1")
|
||||
memory.add(message1)
|
||||
assert memory.count() == 1
|
||||
|
||||
memory.delete_newest()
|
||||
assert memory.count() == 0
|
||||
|
||||
memory.add_batch([message1, message2])
|
||||
assert memory.count() == 2
|
||||
assert len(memory.index.get(message1.cause_by)) == 2
|
||||
|
||||
messages = memory.get_by_role("user1")
|
||||
assert messages[0].content == message1.content
|
||||
|
||||
messages = memory.get_by_content("test message")
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get_by_action(UserRequirement)
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get_by_actions([UserRequirement])
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.try_remember("test message")
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get(k=1)
|
||||
assert len(messages) == 1
|
||||
|
||||
messages = memory.get(k=5)
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.find_news([message3])
|
||||
assert len(messages) == 1
|
||||
|
||||
memory.delete(message1)
|
||||
assert memory.count() == 1
|
||||
messages = memory.get_by_role("user2")
|
||||
assert messages[0].content == message2.content
|
||||
|
||||
memory.clear()
|
||||
assert memory.count() == 0
|
||||
assert len(memory.index) == 0
|
||||
|
|
@ -4,20 +4,28 @@
|
|||
@Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
idea = "Write a cli snake game"
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"))
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
|
|
@ -27,12 +35,12 @@ def test_idea_message():
|
|||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
|
|
@ -50,6 +58,8 @@ def test_actionout_message():
|
|||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"))
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
|
|
@ -59,12 +69,12 @@ def test_actionout_message():
|
|||
|
||||
sim_conent = "The request is command-line interface (CLI) snake game"
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
|
|
|
|||
29
tests/metagpt/provider/test_anthropic_api.py
Normal file
29
tests/metagpt/provider/test_anthropic_api.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Claude2
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
|
||||
prompt = "who are you"
|
||||
resp = "I'am Claude2"
|
||||
|
||||
|
||||
def mock_llm_ask(self, msg: str) -> str:
|
||||
return resp
|
||||
|
||||
|
||||
async def mock_llm_aask(self, msg: str) -> str:
|
||||
return resp
|
||||
|
||||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask)
|
||||
assert resp == Claude2().ask(prompt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask)
|
||||
assert resp == await Claude2().aask(prompt)
|
||||
|
|
@ -6,10 +6,106 @@
|
|||
@File : test_base_gpt_api.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Message
|
||||
|
||||
default_chat_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'am GPT",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
prompt_msg = "who are you"
|
||||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
|
||||
def test_message():
|
||||
message = Message(role="user", content="wtf")
|
||||
|
||||
class MockBaseGPTAPI(BaseGPTAPI):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
return resp_content
|
||||
|
||||
async def close(self):
|
||||
return default_chat_resp
|
||||
|
||||
|
||||
def test_base_gpt_api():
|
||||
message = Message(role="user", content="hello")
|
||||
assert "role" in message.to_dict()
|
||||
assert "user" in str(message)
|
||||
|
||||
base_gpt_api = MockBaseGPTAPI()
|
||||
msg_prompt = base_gpt_api.messages_to_prompt([message])
|
||||
assert msg_prompt == "user: hello"
|
||||
|
||||
msg_dict = base_gpt_api.messages_to_dict([message])
|
||||
assert msg_dict == [{"role": "user", "content": "hello"}]
|
||||
|
||||
openai_funccall_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "test",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_Y5r6Ddr2Qc2ZrqgfwzPX5l72",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "execute",
|
||||
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
func: dict = base_gpt_api.get_choice_function(openai_funccall_resp)
|
||||
assert func == {
|
||||
"name": "execute",
|
||||
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
|
||||
}
|
||||
|
||||
func_args: dict = base_gpt_api.get_choice_function_arguments(openai_funccall_resp)
|
||||
assert func_args == {"language": "python", "code": "print('Hello, World!')"}
|
||||
|
||||
choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
|
||||
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
|
||||
|
||||
resp = base_gpt_api.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = base_gpt_api.ask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = base_gpt_api.ask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_base_gpt_api():
|
||||
base_gpt_api = MockBaseGPTAPI()
|
||||
|
||||
resp = await base_gpt_api.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_gpt_api.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_gpt_api.aask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
91
tests/metagpt/provider/test_fireworks_api.py
Normal file
91
tests/metagpt/provider/test_fireworks_api.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of fireworks api
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireWorksGPTAPI,
|
||||
)
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="accounts/fireworks/models/llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content))
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
def test_fireworks_costmanager():
|
||||
cost_manager = FireworksCostManager()
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
|
||||
|
||||
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return default_resp.choices[0].message.content
|
||||
|
||||
|
||||
def test_fireworks_completion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
|
||||
resp = fireworks_gpt.completion(messages)
|
||||
assert resp.choices[0].message.content == resp_content
|
||||
|
||||
resp = fireworks_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages, stream=False)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
20
tests/metagpt/provider/test_general_api_requestor.py
Normal file
20
tests/metagpt/provider/test_general_api_requestor.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of APIRequestor
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
|
||||
api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com")
|
||||
|
||||
|
||||
def test_api_requestor():
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
assert b"baidu" in resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_api_requestor():
|
||||
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
|
||||
assert b"baidu" in resp
|
||||
|
|
@ -9,33 +9,62 @@ import pytest
|
|||
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
|
||||
messages = [{"role": "user", "parts": "who are you"}]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockGeminiResponse(ABC):
|
||||
text: str
|
||||
|
||||
|
||||
default_resp = MockGeminiResponse(text="I'm gemini from google")
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "parts": prompt_msg}]
|
||||
resp_content = "I'm gemini from google"
|
||||
default_resp = MockGeminiResponse(text=resp_content)
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse:
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(
|
||||
self, messgaes: list[dict], stream: bool = False, timeout: int = 60
|
||||
) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask)
|
||||
resp = GeminiGPTAPI().completion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
resp = gemini_gpt.completion(messages)
|
||||
assert resp.text == resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
resp = gemini_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask)
|
||||
resp = await GeminiGPTAPI().acompletion(messages)
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
|
||||
resp = await gemini_gpt.acompletion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
38
tests/metagpt/provider/test_human_provider.py
Normal file
38
tests/metagpt/provider/test_human_provider.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of HumanProvider
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
resp_content = "test"
|
||||
|
||||
|
||||
def mock_llm_ask(msg: str, timeout: int = 3) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
|
||||
return mock_llm_ask(msg)
|
||||
|
||||
|
||||
def test_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
assert resp_content == human_provider.ask(None)
|
||||
|
||||
assert not human_provider.completion(messages=[])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
resp = await human_provider.aask(None)
|
||||
assert resp_content == resp
|
||||
|
||||
resp = await human_provider.acompletion([])
|
||||
assert not resp
|
||||
17
tests/metagpt/provider/test_metagpt_llm_api.py
Normal file
17
tests/metagpt/provider/test_metagpt_llm_api.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/30
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_llm_api.py
|
||||
"""
|
||||
from metagpt.provider.metagpt_api import MetaGPTAPI
|
||||
|
||||
|
||||
def test_metagpt():
|
||||
llm = MetaGPTAPI()
|
||||
assert llm
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_metagpt()
|
||||
|
|
@ -4,30 +4,58 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm ollama"
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_content}}
|
||||
|
||||
CONFIG.ollama_api_base = "http://xxx"
|
||||
|
||||
|
||||
default_resp = {"message": {"role": "assisant", "content": "I'm ollama"}}
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> dict:
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask)
|
||||
resp = OllamaGPTAPI().completion(messages)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
resp = ollama_gpt.completion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict]) -> dict:
|
||||
return default_resp
|
||||
resp = ollama_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask)
|
||||
resp = await OllamaGPTAPI().acompletion(messages)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -85,14 +85,23 @@ def test_ask_code_list_str():
|
|||
class TestOpenAI:
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other")
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_azure(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
|
@ -101,7 +110,9 @@ class TestOpenAI:
|
|||
def config_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
|
@ -110,8 +121,10 @@ class TestOpenAI:
|
|||
def config_azure_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
|
@ -129,8 +142,8 @@ class TestOpenAI:
|
|||
instance = OpenAIGPTAPI()
|
||||
instance.config = config_azure
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
|
||||
assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert "http_client" not in kwargs
|
||||
assert "http_client" not in async_kwargs
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,51 @@
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.spark_api import SparkAPI
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.spark_api import SparkGPTAPI
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
||||
|
||||
def test_message():
|
||||
llm = SparkAPI()
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str:
|
||||
return resp_content
|
||||
|
||||
logger.info(llm.ask('只回答"收到了"这三个字。'))
|
||||
result = llm.ask("写一篇五百字的日记")
|
||||
logger.info(result)
|
||||
assert len(result) > 100
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_spark_completion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
|
||||
resp = spark_gpt.completion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = spark_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
|
||||
resp = await spark_gpt.acompletion([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -4,34 +4,62 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
|
||||
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": "I'm chatglm-turbo"}]}}
|
||||
CONFIG.zhipuai_api_key = "xxx"
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}}
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> dict:
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_zhipuai_completion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
|
||||
resp = ZhiPuAIGPTAPI().completion(messages)
|
||||
resp = zhipu_gpt.completion(messages)
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dict:
|
||||
return default_resp
|
||||
resp = zhipu_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
|
||||
resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False)
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
@Time : 2023/5/12 13:05
|
||||
@Author : alexanderwu
|
||||
@File : mock.py
|
||||
@File : mock_markdown.py
|
||||
"""
|
||||
from metagpt.actions import UserRequirement, WriteDesign, WritePRD, WriteTasks
|
||||
from metagpt.schema import Message
|
||||
|
|
|
|||
100
tests/metagpt/roles/test_assistant.py
Normal file
100
tests/metagpt/roles/test_assistant.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/25
|
||||
@Author : mashenquan
|
||||
@File : test_asssistant.py
|
||||
@Desc : Used by AgentStore.
|
||||
"""
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions.skill_action import SkillAction
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
from metagpt.roles.assistant import Assistant
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
CONFIG.language = "Chinese"
|
||||
|
||||
class Input(BaseModel):
|
||||
memory: BrainMemory
|
||||
language: str
|
||||
agent_description: str
|
||||
cause_by: str
|
||||
|
||||
inputs = [
|
||||
{
|
||||
"memory": {
|
||||
"history": [
|
||||
{
|
||||
"content": "who is tulin",
|
||||
"role": "user",
|
||||
"id": 1,
|
||||
},
|
||||
{"content": "The one who eaten a poison apple.", "role": "assistant"},
|
||||
],
|
||||
"knowledge": [{"content": "tulin is a scientist."}],
|
||||
"last_talk": "what's apple?",
|
||||
},
|
||||
"language": "English",
|
||||
"agent_description": "chatterbox",
|
||||
"cause_by": any_to_str(TalkAction),
|
||||
},
|
||||
{
|
||||
"memory": {
|
||||
"history": [
|
||||
{
|
||||
"content": "can you draw me an picture?",
|
||||
"role": "user",
|
||||
"id": 1,
|
||||
},
|
||||
{"content": "Yes, of course. What do you want me to draw", "role": "assistant"},
|
||||
],
|
||||
"knowledge": [{"content": "tulin is a scientist."}],
|
||||
"last_talk": "Draw me an apple.",
|
||||
},
|
||||
"language": "English",
|
||||
"agent_description": "painter",
|
||||
"cause_by": any_to_str(SkillAction),
|
||||
},
|
||||
]
|
||||
CONFIG.agent_skills = [
|
||||
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
|
||||
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
|
||||
]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
CONFIG.language = seed.language
|
||||
CONFIG.agent_description = seed.agent_description
|
||||
role = Assistant(language="Chinese")
|
||||
role.memory = seed.memory # Restore historical conversation content.
|
||||
while True:
|
||||
has_action = await role.think()
|
||||
if not has_action:
|
||||
break
|
||||
msg: Message = await role.act()
|
||||
logger.info(msg)
|
||||
assert msg
|
||||
assert msg.cause_by == seed.cause_by
|
||||
assert msg.content
|
||||
# # Retrieve user terminal input.
|
||||
# logger.info("Enter prompt")
|
||||
# talk = input("You: ")
|
||||
# await role.talk(talk)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
157
tests/metagpt/roles/test_teacher.py
Normal file
157
tests/metagpt/roles/test_teacher.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/27 13:25
|
||||
@Author : mashenquan
|
||||
@File : test_teacher.py
|
||||
"""
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.roles.teacher import Teacher
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init():
|
||||
class Inputs(BaseModel):
|
||||
name: str
|
||||
profile: str
|
||||
goal: str
|
||||
constraints: str
|
||||
desc: str
|
||||
kwargs: Optional[Dict] = None
|
||||
expect_name: str
|
||||
expect_profile: str
|
||||
expect_goal: str
|
||||
expect_constraints: str
|
||||
expect_desc: str
|
||||
|
||||
inputs = [
|
||||
{
|
||||
"name": "Lily{language}",
|
||||
"expect_name": "Lily{language}",
|
||||
"profile": "X {teaching_language}",
|
||||
"expect_profile": "X {teaching_language}",
|
||||
"goal": "Do {something_big}, {language}",
|
||||
"expect_goal": "Do {something_big}, {language}",
|
||||
"constraints": "Do in {key1}, {language}",
|
||||
"expect_constraints": "Do in {key1}, {language}",
|
||||
"kwargs": {},
|
||||
"desc": "aaa{language}",
|
||||
"expect_desc": "aaa{language}",
|
||||
},
|
||||
{
|
||||
"name": "Lily{language}",
|
||||
"expect_name": "LilyCN",
|
||||
"profile": "X {teaching_language}",
|
||||
"expect_profile": "X EN",
|
||||
"goal": "Do {something_big}, {language}",
|
||||
"expect_goal": "Do sleep, CN",
|
||||
"constraints": "Do in {key1}, {language}",
|
||||
"expect_constraints": "Do in HaHa, CN",
|
||||
"kwargs": {"language": "CN", "key1": "HaHa", "something_big": "sleep", "teaching_language": "EN"},
|
||||
"desc": "aaa{language}",
|
||||
"expect_desc": "aaaCN",
|
||||
},
|
||||
]
|
||||
|
||||
env = os.environ.copy()
|
||||
for i in inputs:
|
||||
seed = Inputs(**i)
|
||||
os.environ.clear()
|
||||
os.environ.update(env)
|
||||
CONFIG = Config()
|
||||
CONFIG.set_context(seed.kwargs)
|
||||
print(CONFIG.options)
|
||||
assert bool("language" in seed.kwargs) == bool("language" in CONFIG.options)
|
||||
|
||||
teacher = Teacher(
|
||||
name=seed.name,
|
||||
profile=seed.profile,
|
||||
goal=seed.goal,
|
||||
constraints=seed.constraints,
|
||||
desc=seed.desc,
|
||||
)
|
||||
assert teacher.name == seed.expect_name
|
||||
assert teacher.desc == seed.expect_desc
|
||||
assert teacher.profile == seed.expect_profile
|
||||
assert teacher.goal == seed.expect_goal
|
||||
assert teacher.constraints == seed.expect_constraints
|
||||
assert teacher.course_title == "teaching_plan"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_file_name():
|
||||
class Inputs(BaseModel):
|
||||
lesson_title: str
|
||||
ext: str
|
||||
expect: str
|
||||
|
||||
inputs = [
|
||||
{"lesson_title": "# @344\n12", "ext": ".md", "expect": "_344_12.md"},
|
||||
{"lesson_title": "1#@$%!*&\\/:*?\"<>|\n\t '1", "ext": ".cc", "expect": "1_1.cc"},
|
||||
]
|
||||
for i in inputs:
|
||||
seed = Inputs(**i)
|
||||
result = Teacher.new_file_name(seed.lesson_title, seed.ext)
|
||||
assert result == seed.expect
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
CONFIG.set_context({"language": "Chinese", "teaching_language": "English"})
|
||||
lesson = """
|
||||
UNIT 1 Making New Friends
|
||||
TOPIC 1 Welcome to China!
|
||||
Section A
|
||||
|
||||
1a Listen and number the following names.
|
||||
Jane Mari Kangkang Michael
|
||||
Look, listen and understand. Then practice the conversation.
|
||||
Work in groups. Introduce yourself using
|
||||
I ’m ... Then practice 1a
|
||||
with your own hometown or the following places.
|
||||
|
||||
1b Listen and number the following names
|
||||
Jane Michael Maria Kangkang
|
||||
1c Work in groups. Introduce yourself using I ’m ... Then practice 1a with your own hometown or the following places.
|
||||
China the USA the UK Hong Kong Beijing
|
||||
|
||||
2a Look, listen and understand. Then practice the conversation
|
||||
Hello!
|
||||
Hello!
|
||||
Hello!
|
||||
Hello! Are you Maria?
|
||||
No, I’m not. I’m Jane.
|
||||
Oh, nice to meet you, Jane
|
||||
Nice to meet you, too.
|
||||
Hi, Maria!
|
||||
Hi, Kangkang!
|
||||
Welcome to China!
|
||||
Thanks.
|
||||
|
||||
2b Work in groups. Make up a conversation with your own name and the
|
||||
following structures.
|
||||
A: Hello! / Good morning! / Hi! I’m ... Are you ... ?
|
||||
B: ...
|
||||
|
||||
3a Listen, say and trace
|
||||
Aa Bb Cc Dd Ee Ff Gg
|
||||
|
||||
3b Listen and number the following letters. Then circle the letters with the same sound as Bb.
|
||||
Aa Bb Cc Dd Ee Ff Gg
|
||||
|
||||
3c Match the big letters with the small ones. Then write them on the lines.
|
||||
"""
|
||||
teacher = Teacher()
|
||||
rsp = await teacher.run(Message(content=lesson))
|
||||
assert rsp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -8,8 +8,6 @@ from functools import wraps
|
|||
from importlib import import_module
|
||||
|
||||
from metagpt.actions import Action, ActionOutput, WritePRD
|
||||
|
||||
# from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
|
|
|||
|
|
@ -93,4 +93,8 @@ async def test_role_serdeser_interrupt():
|
|||
assert new_role_a._rc.state == 1
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -85,3 +85,4 @@ class RoleC(Role):
|
|||
self._init_actions([ActionOK, ActionRaise])
|
||||
self._watch([UserRequirement])
|
||||
self._rc.react_mode = RoleReactMode.BY_ORDER
|
||||
self._rc.memory.ignore_id = True
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
@Time : 2023/5/12 00:47
|
||||
@Author : alexanderwu
|
||||
@File : test_environment.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
|
@ -11,9 +13,9 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.logs import logger
|
||||
from metagpt.manager import Manager
|
||||
from metagpt.roles import Architect, ProductManager, Role
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -44,6 +46,10 @@ def test_get_roles(env: Environment):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_and_process_message(env: Environment):
|
||||
if CONFIG.git_repo:
|
||||
CONFIG.git_repo.delete_repository()
|
||||
CONFIG.git_repo = None
|
||||
|
||||
product_manager = ProductManager(name="Alice", profile="Product Manager", goal="做AI Native产品", constraints="资源有限")
|
||||
architect = Architect(
|
||||
name="Bob", profile="Architect", goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口", constraints="资源有限,需要节省成本"
|
||||
|
|
@ -51,9 +57,11 @@ async def test_publish_and_process_message(env: Environment):
|
|||
|
||||
env.add_roles([product_manager, architect])
|
||||
|
||||
env.set_manager(Manager())
|
||||
env.publish_message(Message(role="User", content="需要一个基于LLM做总结的搜索引擎", cause_by=UserRequirement))
|
||||
|
||||
await env.run(k=2)
|
||||
logger.info(f"{env.history=}")
|
||||
assert len(env.history) > 10
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -5,9 +5,10 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_gpt.py
|
||||
"""
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -18,34 +19,44 @@ class TestGPT:
|
|||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
# def test_gptapi_ask_batch(self, llm_api):
|
||||
# answer = llm_api.ask_batch(['请扮演一个Google Python专家工程师,如果理解,回复明白', '写一个hello world'])
|
||||
# assert len(answer) > 0
|
||||
def test_gptapi_ask_batch(self, llm_api):
|
||||
answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"], timeout=60)
|
||||
assert len(answer) > 0
|
||||
|
||||
def test_llm_api_ask_code(self, llm_api):
|
||||
answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"])
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
try:
|
||||
answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"])
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
except openai.BadRequestError:
|
||||
assert CONFIG.OPENAI_API_TYPE == "azure"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_aask(self, llm_api):
|
||||
answer = await llm_api.aask("hello chatgpt")
|
||||
answer = await llm_api.aask("hello chatgpt", stream=False)
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
answer = await llm_api.aask("hello chatgpt", stream=True)
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_aask_code(self, llm_api):
|
||||
answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"])
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
try:
|
||||
answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"], timeout=60)
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
except openai.BadRequestError:
|
||||
assert CONFIG.OPENAI_API_TYPE == "azure"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_costs(self, llm_api):
|
||||
await llm_api.aask("hello chatgpt")
|
||||
await llm_api.aask("hello chatgpt", stream=False)
|
||||
costs = llm_api.get_costs()
|
||||
logger.info(costs)
|
||||
assert costs.total_cost > 0
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# pytest.main([__file__, "-s"])
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -4,11 +4,12 @@
|
|||
@Time : 2023/5/11 14:45
|
||||
@Author : alexanderwu
|
||||
@File : test_llm.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
@ -18,7 +19,8 @@ def llm():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_aask(llm):
|
||||
assert len(await llm.aask("hello world")) > 0
|
||||
rsp = await llm.aask("hello world", stream=False)
|
||||
assert len(rsp) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -29,10 +31,11 @@ async def test_llm_aask_batch(llm):
|
|||
@pytest.mark.asyncio
|
||||
async def test_llm_acompletion(llm):
|
||||
hello_msg = [{"role": "user", "content": "hello"}]
|
||||
assert len(await llm.acompletion(hello_msg)) > 0
|
||||
rsp = await llm.acompletion(hello_msg)
|
||||
assert len(rsp.choices[0].message.content) > 0
|
||||
assert len(await llm.acompletion_batch([hello_msg])) > 0
|
||||
assert len(await llm.acompletion_batch_text([hello_msg])) > 0
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# pytest.main([__file__, "-s"])
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
0
tests/metagpt/test_repo_parser.py
Normal file
0
tests/metagpt/test_repo_parser.py
Normal file
|
|
@ -10,8 +10,6 @@
|
|||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
|
|
@ -19,7 +17,6 @@ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
|
|||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_messages():
|
||||
test_content = "test_message"
|
||||
msgs = [
|
||||
|
|
@ -33,7 +30,6 @@ def test_messages():
|
|||
assert all([i in text for i in roles])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_message():
|
||||
m = Message(content="a", role="v1")
|
||||
v = m.dump()
|
||||
|
|
@ -64,7 +60,6 @@ def test_message():
|
|||
assert m.content == "b"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_routes():
|
||||
m = Message(content="a", role="b", cause_by="c", x="d", send_to="c")
|
||||
m.send_to = "b"
|
||||
|
|
|
|||
|
|
@ -26,3 +26,7 @@ async def test_team():
|
|||
# def test_startup():
|
||||
# args = ["Make a 2048 game"]
|
||||
# result = runner.invoke(app, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -100,3 +100,7 @@ async def test_subscription_run_error(loguru_caplog):
|
|||
logs = "".join(loguru_caplog.messages)
|
||||
assert "run error" in logs
|
||||
assert "has completed" in logs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
44
tests/metagpt/tools/test_azure_tts.py
Normal file
44
tests/metagpt/tools/test_azure_tts.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/1 22:50
|
||||
@Author : alexanderwu
|
||||
@File : test_azure_tts.py
|
||||
@Modified By: mashenquan, 2023-8-9, add more text formatting options
|
||||
@Modified By: mashenquan, 2023-8-17, move to `tools` folder.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.azure_tts import AzureTTS
|
||||
|
||||
|
||||
def test_azure_tts():
|
||||
azure_tts = AzureTTS(subscription_key="", region="")
|
||||
text = """
|
||||
女儿看见父亲走了进来,问道:
|
||||
<mstts:express-as role="YoungAdultFemale" style="calm">
|
||||
“您来的挺快的,怎么过来的?”
|
||||
</mstts:express-as>
|
||||
父亲放下手提包,说:
|
||||
<mstts:express-as role="OlderAdultMale" style="calm">
|
||||
“Writing a binary file in Python is similar to writing a regular text file, but you'll work with bytes instead of strings.”
|
||||
</mstts:express-as>
|
||||
"""
|
||||
path = CONFIG.workspace / "tts"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
filename = path / "girl.wav"
|
||||
loop = asyncio.new_event_loop()
|
||||
v = loop.create_task(
|
||||
azure_tts.synthesize_speech(lang="zh-CN", voice="zh-CN-XiaomoNeural", text=text, output_file=str(filename))
|
||||
)
|
||||
result = loop.run_until_complete(v)
|
||||
|
||||
print(result)
|
||||
|
||||
# 运行需要先配置 SUBSCRIPTION_KEY
|
||||
# TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_azure_tts()
|
||||
|
|
@ -1,5 +1,10 @@
|
|||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.tools import WebBrowserEngineType, web_browser_engine
|
||||
|
||||
|
||||
|
|
@ -13,7 +18,8 @@ from metagpt.tools import WebBrowserEngineType, web_browser_engine
|
|||
ids=["playwright", "selenium"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, url, urls):
|
||||
browser = web_browser_engine.WebBrowserEngine(browser_type)
|
||||
conf = Config()
|
||||
browser = web_browser_engine.WebBrowserEngine(options=conf.runtime_options, engine=browser_type)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, str)
|
||||
assert "深度赋智" in result
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import Config
|
||||
from metagpt.tools import web_browser_engine_playwright
|
||||
|
||||
|
||||
|
|
@ -15,22 +19,25 @@ from metagpt.tools import web_browser_engine_playwright
|
|||
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd):
|
||||
conf = Config()
|
||||
global_proxy = conf.global_proxy
|
||||
try:
|
||||
global_proxy = CONFIG.global_proxy
|
||||
if use_proxy:
|
||||
CONFIG.global_proxy = proxy
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type, **kwagrs)
|
||||
conf.global_proxy = proxy
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(
|
||||
options=conf.runtime_options, browser_type=browser_type, **kwagrs
|
||||
)
|
||||
result = await browser.run(url)
|
||||
result = result.inner_text
|
||||
assert isinstance(result, str)
|
||||
assert "Deepwisdom" in result
|
||||
assert "DeepWisdom" in result
|
||||
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("Deepwisdom" in i) for i in results)
|
||||
assert all(("DeepWisdom" in i) for i in results)
|
||||
if use_proxy:
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
finally:
|
||||
CONFIG.global_proxy = global_proxy
|
||||
conf.global_proxy = global_proxy
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import Config
|
||||
from metagpt.tools import web_browser_engine_selenium
|
||||
|
||||
|
||||
|
|
@ -15,11 +19,12 @@ from metagpt.tools import web_browser_engine_selenium
|
|||
ids=["chrome-normal", "firefox-normal", "edge-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd):
|
||||
conf = Config()
|
||||
global_proxy = conf.global_proxy
|
||||
try:
|
||||
global_proxy = CONFIG.global_proxy
|
||||
if use_proxy:
|
||||
CONFIG.global_proxy = proxy
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type)
|
||||
conf.global_proxy = proxy
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(options=conf.runtime_options, browser_type=browser_type)
|
||||
result = await browser.run(url)
|
||||
result = result.inner_text
|
||||
assert isinstance(result, str)
|
||||
|
|
@ -33,4 +38,4 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd)
|
|||
if use_proxy:
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
finally:
|
||||
CONFIG.global_proxy = global_proxy
|
||||
conf.global_proxy = global_proxy
|
||||
|
|
|
|||
|
|
@ -4,19 +4,15 @@
|
|||
@Time : 2023/5/1 11:19
|
||||
@Author : alexanderwu
|
||||
@File : test_config.py
|
||||
@Modified By: mashenquan, 2013/8/20, Add `test_options`; remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import Config
|
||||
|
||||
|
||||
def test_config_class_is_singleton():
|
||||
config_1 = Config()
|
||||
config_2 = Config()
|
||||
assert config_1 == config_2
|
||||
|
||||
|
||||
def test_config_class_get_key_exception():
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
config = Config()
|
||||
|
|
@ -28,4 +24,14 @@ def test_config_yaml_file_not_exists():
|
|||
config = Config("wtf.yaml")
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
config.get("OPENAI_BASE_URL")
|
||||
assert str(exc_info.value) == "Key 'OPENAI_BASE_URL' not found in environment variables or in the YAML file"
|
||||
assert str(exc_info.value) == "Set OPENAI_API_KEY or Anthropic_API_KEY first"
|
||||
|
||||
|
||||
def test_options():
|
||||
filename = Path(__file__).resolve().parent.parent.parent.parent / "config/config.yaml"
|
||||
config = Config(filename)
|
||||
assert config.options
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_options()
|
||||
|
|
|
|||
85
tests/metagpt/utils/test_di_graph_repository.py
Normal file
85
tests/metagpt/utils/test_di_graph_repository.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : test_di_graph_repository.py
|
||||
@Desc : Unit tests for di_graph_repository.py
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.repo_parser import RepoParser
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphRepository
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_di_graph_repository():
|
||||
class Input(BaseModel):
|
||||
s: str
|
||||
p: str
|
||||
o: str
|
||||
|
||||
inputs = [
|
||||
{"s": "main.py:Game:draw", "p": "method:hasDescription", "o": "Draw image"},
|
||||
{"s": "main.py:Game:draw", "p": "method:hasDescription", "o": "Show image"},
|
||||
]
|
||||
path = Path(__file__).parent
|
||||
graph = DiGraphRepository(name="test", root=path)
|
||||
for i in inputs:
|
||||
data = Input(**i)
|
||||
await graph.insert(subject=data.s, predicate=data.p, object_=data.o)
|
||||
v = graph.json()
|
||||
assert v
|
||||
await graph.save()
|
||||
assert graph.pathname.exists()
|
||||
graph.pathname.unlink()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_js_parser():
|
||||
class Input(BaseModel):
|
||||
path: str
|
||||
|
||||
inputs = [
|
||||
{"path": str(Path(__file__).parent / "../../data/code")},
|
||||
]
|
||||
path = Path(__file__).parent
|
||||
graph = DiGraphRepository(name="test", root=path)
|
||||
for i in inputs:
|
||||
data = Input(**i)
|
||||
repo_parser = RepoParser(base_directory=data.path)
|
||||
symbols = repo_parser.generate_symbols()
|
||||
for s in symbols:
|
||||
await GraphRepository.update_graph_db(graph_db=graph, file_info=s)
|
||||
data = graph.json()
|
||||
assert data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codes():
|
||||
path = DEFAULT_WORKSPACE_ROOT / "snake_game"
|
||||
repo_parser = RepoParser(base_directory=path)
|
||||
|
||||
graph = DiGraphRepository(name="test", root=path)
|
||||
symbols = repo_parser.generate_symbols()
|
||||
for file_info in symbols:
|
||||
for code_block in file_info.page_info:
|
||||
try:
|
||||
val = code_block.json(ensure_ascii=False)
|
||||
assert val
|
||||
except TypeError as e:
|
||||
assert not e
|
||||
await GraphRepository.update_graph_db(graph_db=graph, file_info=file_info)
|
||||
data = graph.json()
|
||||
assert data
|
||||
print(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -15,7 +15,7 @@ def test_count_message_tokens():
|
|||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert count_message_tokens(messages) == 17
|
||||
assert count_message_tokens(messages) == 15
|
||||
|
||||
|
||||
def test_count_message_tokens_with_name():
|
||||
|
|
@ -67,3 +67,7 @@ def test_count_string_tokens_gpt_4():
|
|||
|
||||
string = "Hello, world!"
|
||||
assert count_string_tokens(string, model_name="gpt-4-0314") == 4
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue