Merge branch 'geekan:main' into feat_st_game

This commit is contained in:
better629 2024-03-26 09:51:02 +08:00 committed by GitHub
commit 64350d2c6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
303 changed files with 9155 additions and 3587 deletions

View file

@ -12,8 +12,10 @@ import logging
import os
import re
import uuid
from pathlib import Path
from typing import Callable
import aiohttp.web
import pytest
from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
@ -111,12 +113,13 @@ def proxy():
while not reader.at_eof():
writer.write(await reader.read(2048))
writer.close()
await writer.wait_closed()
async def handle_client(reader, writer):
data = await reader.readuntil(b"\r\n\r\n")
print(f"Proxy: {data}") # checking with capfd fixture
infos = pattern.match(data)
host, port = infos.group("host"), infos.group("port")
print(f"Proxy: {host}") # checking with capfd fixture
port = int(port) if port else 80
remote_reader, remote_writer = await asyncio.open_connection(host, port)
if data.startswith(b"CONNECT"):
@ -171,9 +174,8 @@ def new_filename(mocker):
yield mocker
@pytest.fixture(scope="session")
def search_rsp_cache():
rsp_cache_file_path = TEST_DATA_PATH / "search_rsp_cache.json" # read repo-provided
def _rsp_cache(name):
rsp_cache_file_path = TEST_DATA_PATH / f"{name}.json" # read repo-provided
if os.path.exists(rsp_cache_file_path):
with open(rsp_cache_file_path, "r") as f1:
rsp_cache_json = json.load(f1)
@ -184,6 +186,16 @@ def search_rsp_cache():
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
# Session-scoped fixture: delegates entirely to _rsp_cache using the cache name
# "search_rsp_cache", so setup and teardown are whatever that generator does.
@pytest.fixture(scope="session")
def search_rsp_cache():
yield from _rsp_cache("search_rsp_cache")
# Session-scoped fixture: delegates entirely to _rsp_cache using the cache name
# "mermaid_rsp_cache", so setup and teardown are whatever that generator does.
@pytest.fixture(scope="session")
def mermaid_rsp_cache():
yield from _rsp_cache("mermaid_rsp_cache")
@pytest.fixture
def aiohttp_mocker(mocker):
MockResponse = type("MockResponse", (MockAioResponse,), {})
@ -231,3 +243,40 @@ def search_engine_mocker(aiohttp_mocker, curl_cffi_mocker, httplib2_mocker, sear
aiohttp_mocker.rsp_cache = httplib2_mocker.rsp_cache = curl_cffi_mocker.rsp_cache = search_rsp_cache
aiohttp_mocker.check_funcs = httplib2_mocker.check_funcs = curl_cffi_mocker.check_funcs = check_funcs
yield check_funcs
@pytest.fixture
def http_server():
    """Provide a coroutine function that starts a throwaway local HTTP server.

    The fixture starts nothing by itself; it returns ``start``. Awaiting
    ``start()`` binds an aiohttp server to 127.0.0.1 on an OS-assigned port
    and returns ``(site, url)``, where ``url`` is ``http://127.0.0.1:<port>``.
    Every request receives the same small static HTML page. Nothing in this
    fixture stops the server; the ``site`` object is handed back so the
    caller can do that.
    """

    async def handler(request):
        # Same canned page regardless of the request.
        return aiohttp.web.Response(
            text="""<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8">
<title>MetaGPT</title></head><body><h1>MetaGPT</h1></body></html>""",
            content_type="text/html",
        )

    async def start():
        server = aiohttp.web.Server(handler)
        runner = aiohttp.web.ServerRunner(server)
        await runner.setup()
        # Port 0 asks the OS to choose a free port.
        site = aiohttp.web.TCPSite(runner, "127.0.0.1", 0)
        await site.start()
        # Read the bound port back through the runner's public ``addresses``
        # property (a list of ``getsockname()`` results) instead of reaching
        # into the private ``site._server`` attribute.
        _, port, *_ = runner.addresses[0]
        return site, f"http://127.0.0.1:{port}"

    return start
# Points the aiohttp mock at the mermaid response cache and attaches a fresh, initially
# empty check_funcs mapping to it; the same mapping is yielded so a test can fill it in.
@pytest.fixture
def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache):
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
aiohttp_mocker.rsp_cache = mermaid_rsp_cache
aiohttp_mocker.check_funcs = check_funcs
yield check_funcs
@pytest.fixture
def git_dir():
    """Create and return a uniquely named directory under ``unittest/`` beside this file."""
    tests_root = Path(__file__).parent
    unique_dir = tests_root / f"unittest/{uuid.uuid4().hex}"
    # parents=True also creates ``unittest/`` itself when it does not exist yet.
    unique_dir.mkdir(parents=True, exist_ok=True)
    return unique_dir

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -149,7 +149,7 @@ sequenceDiagram
The requirement analysis suggests the need for a clean and intuitive interface. Since we are using a command-line interface, we need to ensure that the text-based UI is as user-friendly as possible. Further clarification on whether a graphical user interface (GUI) is expected in the future would be helpful for planning the extendability of the game."""
TASKS_SAMPLE = """
TASK_SAMPLE = """
## Required Python packages
- random==2.2.1
@ -345,7 +345,7 @@ REFINED_DESIGN_JSON = {
"Anything UNCLEAR": "",
}
REFINED_TASKS_JSON = {
REFINED_TASK_JSON = {
"Required Python packages": ["random==2.2.1", "Tkinter==8.6"],
"Required Other language third-party packages": ["No third-party dependencies required"],
"Refined Logic Analysis": [
@ -373,7 +373,14 @@ REFINED_TASKS_JSON = {
}
CODE_PLAN_AND_CHANGE_SAMPLE = {
"Code Plan And Change": '\n1. Plan for gui.py: Develop the GUI using Tkinter to replace the command-line interface. Start by setting up the main window and event handling. Then, add widgets for displaying the game status, results, and feedback. Implement interactive elements for difficulty selection and visualize the guess history. Finally, create animations for guess feedback and ensure responsiveness across different screen sizes.\n```python\nclass GUI:\n- pass\n+ def __init__(self):\n+ self.setup_window()\n+\n+ def setup_window(self):\n+ # Initialize the main window using Tkinter\n+ pass\n+\n+ def bind_events(self):\n+ # Bind button clicks and other events\n+ pass\n+\n+ def update_feedback(self, message: str):\n+ # Update the feedback label with the given message\n+ pass\n+\n+ def update_attempts(self, attempts: int):\n+ # Update the attempts label with the number of attempts\n+ pass\n+\n+ def update_history(self, history: list):\n+ # Update the history view with the list of past guesses\n+ pass\n+\n+ def show_difficulty_selector(self):\n+ # Show buttons or a dropdown for difficulty selection\n+ pass\n+\n+ def animate_guess_result(self, correct: bool):\n+ # Trigger an animation for correct or incorrect guesses\n+ pass\n```\n\n2. Plan for main.py: Modify the main.py to initialize the GUI and start the event-driven game loop. Ensure that the GUI is the primary interface for user interaction.\n```python\nclass Main:\n def main(self):\n- user_interface = UI()\n- user_interface.start()\n+ graphical_user_interface = GUI()\n+ graphical_user_interface.setup_window()\n+ graphical_user_interface.bind_events()\n+ # Start the Tkinter main loop\n+ pass\n\n if __name__ == "__main__":\n main_instance = Main()\n main_instance.main()\n```\n\n3. Plan for ui.py: Refactor ui.py to work with the new GUI class. 
Remove command-line interactions and delegate display and input tasks to the GUI.\n```python\nclass UI:\n- def display_message(self, message: str):\n- print(message)\n+\n+ def display_message(self, message: str):\n+ # This method will now pass the message to the GUI to display\n+ pass\n\n- def get_user_input(self, prompt: str) -> str:\n- return input(prompt)\n+\n+ def get_user_input(self, prompt: str) -> str:\n+ # This method will now trigger the GUI to get user input\n+ pass\n\n- def show_attempts(self, attempts: int):\n- print(f"Number of attempts: {attempts}")\n+\n+ def show_attempts(self, attempts: int):\n+ # This method will now update the GUI with the number of attempts\n+ pass\n\n- def show_history(self, history: list):\n- print("Guess history:")\n- for guess in history:\n- print(guess)\n+\n+ def show_history(self, history: list):\n+ # This method will now update the GUI with the guess history\n+ pass\n```\n\n4. Plan for game.py: Ensure game.py remains mostly unchanged as it contains the core game logic. However, make minor adjustments if necessary to integrate with the new GUI.\n```python\nclass Game:\n # No changes required for now\n```\n'
"Development Plan": [
"Develop the GUI using Tkinter to replace the command-line interface. Start by setting up the main window and event handling. Then, add widgets for displaying the game status, results, and feedback. Implement interactive elements for difficulty selection and visualize the guess history. Finally, create animations for guess feedback and ensure responsiveness across different screen sizes.",
"Modify the main.py to initialize the GUI and start the event-driven game loop. Ensure that the GUI is the primary interface for user interaction.",
],
"Incremental Change": [
"""```diff\nclass GUI:\n- pass\n+ def __init__(self):\n+ self.setup_window()\n+\n+ def setup_window(self):\n+ # Initialize the main window using Tkinter\n+ pass\n+\n+ def bind_events(self):\n+ # Bind button clicks and other events\n+ pass\n+\n+ def update_feedback(self, message: str):\n+ # Update the feedback label with the given message\n+ pass\n+\n+ def update_attempts(self, attempts: int):\n+ # Update the attempts label with the number of attempts\n+ pass\n+\n+ def update_history(self, history: list):\n+ # Update the history view with the list of past guesses\n+ pass\n+\n+ def show_difficulty_selector(self):\n+ # Show buttons or a dropdown for difficulty selection\n+ pass\n+\n+ def animate_guess_result(self, correct: bool):\n+ # Trigger an animation for correct or incorrect guesses\n+ pass\n```""",
"""```diff\nclass Main:\n def main(self):\n- user_interface = UI()\n- user_interface.start()\n+ graphical_user_interface = GUI()\n+ graphical_user_interface.setup_window()\n+ graphical_user_interface.bind_events()\n+ # Start the Tkinter main loop\n+ pass\n\n if __name__ == "__main__":\n main_instance = Main()\n main_instance.main()\n```\n\n3. Plan for ui.py: Refactor ui.py to work with the new GUI class. Remove command-line interactions and delegate display and input tasks to the GUI.\n```python\nclass UI:\n- def display_message(self, message: str):\n- print(message)\n+\n+ def display_message(self, message: str):\n+ # This method will now pass the message to the GUI to display\n+ pass\n\n- def get_user_input(self, prompt: str) -> str:\n- return input(prompt)\n+\n+ def get_user_input(self, prompt: str) -> str:\n+ # This method will now trigger the GUI to get user input\n+ pass\n\n- def show_attempts(self, attempts: int):\n- print(f"Number of attempts: {attempts}")\n+\n+ def show_attempts(self, attempts: int):\n+ # This method will now update the GUI with the number of attempts\n+ pass\n\n- def show_history(self, history: list):\n- print("Guess history:")\n- for guess in history:\n- print(guess)\n+\n+ def show_history(self, history: list):\n+ # This method will now update the GUI with the guess history\n+ pass\n```\n\n4. Plan for game.py: Ensure game.py remains mostly unchanged as it contains the core game logic. However, make minor adjustments if necessary to integrate with the new GUI.\n```python\nclass Game:\n # No changes required for now\n```\n""",
],
}
REFINED_CODE_INPUT_SAMPLE = """

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,51 +0,0 @@
# -*- coding: utf-8 -*-
# @Date : 1/11/2024 8:51 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import pytest
from metagpt.actions.ci.debug_code import DebugCode
from metagpt.schema import Message
ErrorStr = """Tested passed:
Tests failed:
assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5]
"""
CODE = """
def sort_array(arr):
# Helper function to count the number of ones in the binary representation
def count_ones(n):
return bin(n).count('1')
# Sort the array using a custom key function
# The key function returns a tuple (number of ones, value) for each element
# This ensures that if two elements have the same number of ones, they are sorted by their value
sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x))
return sorted_arr
```
"""
DebugContext = '''Solve the problem in Python:
def sort_array(arr):
"""
In this Kata, you have to sort an array of non-negative integers according to
number of ones in their binary representation in ascending order.
For similar number of ones, sort based on decimal value.
It must be implemented like this:
>>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
>>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
>>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
"""
'''
# Runs DebugCode with the problem statement (DebugContext), the candidate solution (CODE)
# and the failing-test report (ErrorStr), then checks that the returned "code" entry
# still contains a sort_array definition.
@pytest.mark.asyncio
async def test_debug_code():
debug_context = Message(content=DebugContext)
new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
assert "def sort_array(arr)" in new_code["code"]

View file

@ -1,46 +0,0 @@
import pytest
from metagpt.actions.ci.ml_action import WriteCodeWithToolsML
from metagpt.schema import Plan, Task
# Builds a two-task Plan (task 1 already finished with DataFrame-construction code, task 2 a
# pending "data_preprocess" task that depends on it) and checks that WriteCodeWithToolsML
# returns a non-empty "code" entry for the current task.
@pytest.mark.asyncio
async def test_write_code_with_tools():
write_code_ml = WriteCodeWithToolsML()
# NOTE(review): the code string stored on task 1 uses np.nan but only imports pandas;
# confirm whether that snippet is ever executed downstream.
task_map = {
"1": Task(
task_id="1",
instruction="随机生成一个pandas DataFrame数据集",
task_type="other",
dependent_task_ids=[],
code="""
import pandas as pd
df = pd.DataFrame({
'a': [1, 2, 3, 4, 5],
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
'd': [1, 2, 3, 4, 5]
})
""",
is_finished=True,
),
"2": Task(
task_id="2",
instruction="对数据集进行数据清洗",
task_type="data_preprocess",
dependent_task_ids=["1"],
),
}
# The plan's current task is task 2.
plan = Plan(
goal="构造数据集并进行数据清洗",
tasks=list(task_map.values()),
task_map=task_map,
current_task_id="2",
)
column_info = ""
# run() is called with an empty context list; only the second element of its result is used.
_, code_with_ml = await write_code_ml.run([], plan, column_info)
code_with_ml = code_with_ml["code"]
assert len(code_with_ml) > 0
print(code_with_ml)

View file

@ -1,324 +0,0 @@
import asyncio
import pytest
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
from metagpt.actions.ci.write_analysis_code import (
WriteCodeWithoutTools,
WriteCodeWithTools,
)
from metagpt.logs import logger
from metagpt.schema import Message, Plan, Task
from metagpt.strategy.planner import STRUCTURAL_CONTEXT
# Skipped test. Walks a five-step plan given as plain strings: for each step it asks
# WriteCodeWithoutTools for code, executes the result with ExecuteNbCode, and appends the
# step, the generated code and the execution output to the running message history.
@pytest.mark.skip
@pytest.mark.asyncio
async def test_write_code_by_list_plan():
write_code = WriteCodeWithoutTools()
execute_code = ExecuteNbCode()
messages = []
# Steps, in English: generate a random pandas DataFrame time series; plot its histogram;
# review the finished tasks; compute the mean; summarize.
plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "回顾已完成的任务", "求均值", "总结"]
for task in plan:
print(f"\n任务: {task}\n\n")
messages.append(Message(task, role="assistant"))
code = await write_code.run(messages)
# The "review" and "summarize" steps are expected to come back as markdown, the rest as python.
if task.startswith(("回顾", "总结")):
assert code["language"] == "markdown"
else:
assert code["language"] == "python"
messages.append(Message(code["code"], role="assistant"))
assert len(code) > 0
# code is unpacked as keyword arguments for the executor.
output, _ = await execute_code.run(**code)
print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n")
messages.append(output)
# Offers two tools for a data-cleaning task and expects _recommend_tool to return exactly
# one recommendation, and that it is "FillMissingValue".
@pytest.mark.asyncio
async def test_tool_recommendation():
task = "clean and preprocess the data"
# Maps tool name -> short description.
available_tools = {
"FillMissingValue": "Filling missing values",
"SplitBins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
}
write_code = WriteCodeWithTools()
tools = await write_code._recommend_tool(task, available_tools)
assert len(tools) == 1
assert "FillMissingValue" in tools
# Builds a two-task Plan (task 1 finished, task 2 a pending "data_preprocess" task), renders it
# into a STRUCTURAL_CONTEXT user message, and checks that WriteCodeWithTools returns a
# non-empty "code" entry.
@pytest.mark.asyncio
async def test_write_code_with_tools():
write_code = WriteCodeWithTools()
requirement = "构造数据集并进行数据清洗"
# NOTE(review): the code string stored on task 1 uses np.nan but only imports pandas;
# confirm whether that snippet is ever executed downstream.
task_map = {
"1": Task(
task_id="1",
instruction="随机生成一个pandas DataFrame数据集",
task_type="other",
dependent_task_ids=[],
code="""
import pandas as pd
df = pd.DataFrame({
'a': [1, 2, 3, 4, 5],
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
'd': [1, 2, 3, 4, 5]
})
""",
is_finished=True,
),
"2": Task(
task_id="2",
instruction="对数据集进行数据清洗",
task_type="data_preprocess",
dependent_task_ids=["1"],
),
}
# The plan's current task is task 2.
plan = Plan(
goal="构造数据集并进行数据清洗",
tasks=list(task_map.values()),
task_map=task_map,
current_task_id="2",
)
# Fill the shared prompt template with the requirement, plan context, task list and current task.
context = STRUCTURAL_CONTEXT.format(
user_requirement=requirement,
context=plan.context,
tasks=list(task_map.values()),
current_task=plan.current_task.model_dump_json(),
)
context_msg = [Message(content=context, role="user")]
code = await write_code.run(context_msg, plan)
code = code["code"]
assert len(code) > 0
print(code)
# Replays a failed attempt as a conversation (plan context, wrong code, traceback) and
# expects WriteCodeWithoutTools to answer with code that uses read_csv.
@pytest.mark.asyncio
async def test_write_code_to_correct_error():
# Hand-written plan context: two unfinished tasks, task 1 is the current one.
structural_context = """
## User Requirement
read a dataset test.csv and print its head
## Current Plan
[
{
"task_id": "1",
"dependent_task_ids": [],
"instruction": "import pandas and load the dataset from 'test.csv'.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
},
{
"task_id": "2",
"dependent_task_ids": [
"1"
],
"instruction": "Print the head of the dataset to display the first few rows.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
}
]
## Current Task
{"task_id": "1", "dependent_task_ids": [], "instruction": "import pandas and load the dataset from 'test.csv'.", "task_type": "", "code": "", "result": "", "is_finished": false}
"""
wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv
# Canned traceback text for the wrong code above; nothing is executed in this test.
error = """
Traceback (most recent call last):
File "<stdin>", line 2, in <module>
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 478, in read_excel
io = ExcelFile(io, storage_options=storage_options, engine=engine)
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 1500, in __init__
raise ValueError(
ValueError: Excel file format cannot be determined, you must specify an engine manually.
"""
# Conversation order: user context, assistant's wrong code, user-reported error.
context = [
Message(content=structural_context, role="user"),
Message(content=wrong_code, role="assistant"),
Message(content=error, role="user"),
]
new_code = await WriteCodeWithoutTools().run(context=context)
new_code = new_code["code"]
print(new_code)
assert "read_csv" in new_code # should correct read_excel to read_csv
# Gives WriteCodeWithoutTools a plan in which task 1 (import pandas, read the CSV) is already
# finished and task 2 is current, then checks the generated code repeats neither the import
# nor the read_csv call.
@pytest.mark.asyncio
async def test_write_code_reuse_code_simple():
# Hand-written plan context: task 1 carries its finished code, task 2 is the current task.
structural_context = """
## User Requirement
read a dataset test.csv and print its head
## Current Plan
[
{
"task_id": "1",
"dependent_task_ids": [],
"instruction": "import pandas and load the dataset from 'test.csv'.",
"task_type": "",
"code": "import pandas as pd\ndata = pd.read_csv('test.csv')",
"result": "",
"is_finished": true
},
{
"task_id": "2",
"dependent_task_ids": [
"1"
],
"instruction": "Print the head of the dataset to display the first few rows.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
}
]
## Current Task
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Print the head of the dataset to display the first few rows.", "task_type": "", "code": "", "result": "", "is_finished": false}
"""
context = [
Message(content=structural_context, role="user"),
]
code = await WriteCodeWithoutTools().run(context=context)
code = code["code"]
print(code)
assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one
# Skipped test. Runs the same Iris prompt five times concurrently and requires that at least
# 80% of the answers reuse iris_data without calling load_iris again.
@pytest.mark.skip
@pytest.mark.asyncio
async def test_write_code_reuse_code_long():
"""test code reuse for long context"""
# Hand-written plan context: task 1 (load Iris) is finished with code and result, task 2 is current.
structural_context = """
## User Requirement
Run data analysis on sklearn Iris dataset, include a plot
## Current Plan
[
{
"task_id": "1",
"dependent_task_ids": [],
"instruction": "Load the Iris dataset from sklearn.",
"task_type": "",
"code": "from sklearn.datasets import load_iris\niris_data = load_iris()\niris_data['data'][0:5], iris_data['target'][0:5]",
"result": "(array([[5.1, 3.5, 1.4, 0.2],\n [4.9, 3. , 1.4, 0.2],\n [4.7, 3.2, 1.3, 0.2],\n [4.6, 3.1, 1.5, 0.2],\n [5. , 3.6, 1.4, 0.2]]),\n array([0, 0, 0, 0, 0]))",
"is_finished": true
},
{
"task_id": "2",
"dependent_task_ids": [
"1"
],
"instruction": "Perform exploratory data analysis on the Iris dataset.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
},
{
"task_id": "3",
"dependent_task_ids": [
"2"
],
"instruction": "Create a plot visualizing the Iris dataset features.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
}
]
## Current Task
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Perform exploratory data analysis on the Iris dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
"""
context = [
Message(content=structural_context, role="user"),
]
trials_num = 5
# Build all coroutines first, then await them together.
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
trial_results = await asyncio.gather(*trials)
print(*trial_results, sep="\n\n***\n\n")
success = [
"load_iris" not in result["code"] and "iris_data" in result["code"] for result in trial_results
] # should reuse iris_data from previous tasks
# Booleans sum as 0/1, so this is the fraction of successful trials.
success_rate = sum(success) / trials_num
logger.info(f"success rate: {success_rate :.2f}")
assert success_rate >= 0.8
# Skipped test. Runs the same Wine prompt five times concurrently and requires that at least
# 80% of the answers reuse wine_data without calling load_wine again.
@pytest.mark.skip
@pytest.mark.asyncio
async def test_write_code_reuse_code_long_for_wine():
"""test code reuse for long context"""
# Hand-written plan context: task 1 (load Wine + EDA) is finished, task 2 is current.
# NOTE(review): the "User Requirement" text names the Wisconsin Breast Cancer dataset while
# every task refers to the Wine dataset -- confirm which one is intended.
# NOTE(review): the "instruction" entries of tasks 1 and 2 have no trailing comma, so the
# plan text is not valid JSON; it is only used as prompt text in this test.
structural_context = """
## User Requirement
Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy
## Current Plan
[
{
"task_id": "1",
"dependent_task_ids": [],
"instruction": "Load the sklearn Wine recognition dataset and perform exploratory data analysis."
"task_type": "",
"code": "from sklearn.datasets import load_wine\n# Load the Wine recognition dataset\nwine_data = load_wine()\n# Perform exploratory data analysis\nwine_data.keys()",
"result": "Truncated to show only the last 1000 characters\ndict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names'])",
"is_finished": true
},
{
"task_id": "2",
"dependent_task_ids": ["1"],
"instruction": "Create a plot to visualize some aspect of the wine dataset."
"task_type": "",
"code": "",
"result": "",
"is_finished": false
},
{
"task_id": "3",
"dependent_task_ids": ["1"],
"instruction": "Split the dataset into training and validation sets with a 20% validation size.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
},
{
"task_id": "4",
"dependent_task_ids": ["3"],
"instruction": "Train a model on the training set to predict wine class.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
},
{
"task_id": "5",
"dependent_task_ids": ["4"],
"instruction": "Evaluate the model on the validation set and report the accuracy.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
}
]
## Current Task
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Create a plot to visualize some aspect of the Wine dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
"""
context = [
Message(content=structural_context, role="user"),
]
trials_num = 5
# Build all coroutines first, then await them together.
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
trial_results = await asyncio.gather(*trials)
print(*trial_results, sep="\n\n***\n\n")
success = [
"load_wine" not in result["code"] and "wine_data" in result["code"] for result in trial_results
] # should reuse wine_data from previous tasks
# Booleans sum as 0/1, so this is the fraction of successful trials.
success_rate = sum(success) / trials_num
logger.info(f"success rate: {success_rate :.2f}")
assert success_rate >= 0.8

View file

@ -1,6 +1,6 @@
import pytest
from metagpt.actions.ci.ask_review import AskReview
from metagpt.actions.di.ask_review import AskReview
@pytest.mark.asyncio

View file

@ -1,6 +1,6 @@
import pytest
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode, truncate
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
@pytest.mark.asyncio
@ -8,6 +8,7 @@ async def test_code_running():
executor = ExecuteNbCode()
output, is_success = await executor.run("print('hello world!')")
assert is_success
await executor.terminate()
@pytest.mark.asyncio
@ -17,6 +18,7 @@ async def test_split_code_running():
_ = await executor.run("z=x+y")
output, is_success = await executor.run("assert z==3")
assert is_success
await executor.terminate()
@pytest.mark.asyncio
@ -24,6 +26,7 @@ async def test_execute_error():
executor = ExecuteNbCode()
output, is_success = await executor.run("z=1/0")
assert not is_success
await executor.terminate()
PLOT_CODE = """
@ -52,21 +55,7 @@ async def test_plotting_code():
executor = ExecuteNbCode()
output, is_success = await executor.run(PLOT_CODE)
assert is_success
def test_truncate():
# 代码执行成功
output, is_success = truncate("hello world", 5, True)
assert "Truncated to show only first 5 characters\nhello" in output
assert is_success
# 代码执行失败
output, is_success = truncate("hello world", 5, False)
assert "Truncated to show only last 5 characters\nworld" in output
assert not is_success
# 异步
output, is_success = truncate("<coroutine object", 5, True)
assert not is_success
assert "await" in output
await executor.terminate()
@pytest.mark.asyncio
@ -76,6 +65,7 @@ async def test_run_with_timeout():
message, success = await executor.run(code)
assert not success
assert message.startswith("Cell execution timed out")
await executor.terminate()
@pytest.mark.asyncio
@ -83,7 +73,7 @@ async def test_run_code_text():
executor = ExecuteNbCode()
message, success = await executor.run(code='print("This is a code!")', language="python")
assert success
assert message == "This is a code!\n"
assert "This is a code!" in message
message, success = await executor.run(code="# This is a code!", language="markdown")
assert success
assert message == "# This is a code!"
@ -91,19 +81,22 @@ async def test_run_code_text():
message, success = await executor.run(code=mix_text, language="markdown")
assert success
assert message == mix_text
await executor.terminate()
@pytest.mark.asyncio
async def test_terminate():
executor = ExecuteNbCode()
await executor.run(code='print("This is a code!")', language="python")
is_kernel_alive = await executor.nb_client.km.is_alive()
assert is_kernel_alive
await executor.terminate()
import time
time.sleep(2)
assert executor.nb_client.km is None
@pytest.mark.parametrize(
"k", [(1), (5)]
) # k=1 to test a single regular terminate, k>1 to test terminate under continuous run
async def test_terminate(k):
for _ in range(k):
executor = ExecuteNbCode()
await executor.run(code='print("This is a code!")', language="python")
is_kernel_alive = await executor.nb_client.km.is_alive()
assert is_kernel_alive
await executor.terminate()
assert executor.nb_client.km is None
assert executor.nb_client.kc is None
@pytest.mark.asyncio
@ -114,3 +107,22 @@ async def test_reset():
assert is_kernel_alive
await executor.reset()
assert executor.nb_client.km is None
await executor.terminate()
# Executes a cell that prints twice and then fails, and checks that the returned output
# contains both the printed text and the KeyError message while the run is reported as failed.
@pytest.mark.asyncio
async def test_parse_outputs():
executor = ExecuteNbCode()
# The DataFrame only has 'ID' and 'NAME'; the last line indexes a column that does not exist.
code = """
import pandas as pd
df = pd.DataFrame({'ID': [1,2,3], 'NAME': ['a', 'b', 'c']})
print(df.columns)
print(f"columns num:{len(df.columns)}")
print(df['DUMMPY_ID'])
"""
output, is_success = await executor.run(code)
assert not is_success
assert "Index(['ID', 'NAME'], dtype='object')" in output
assert "KeyError: 'DUMMPY_ID'" in output
assert "columns num:2" in output
await executor.terminate()

View file

@ -0,0 +1,79 @@
import pytest
from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
from metagpt.schema import Message
# Calls WriteAnalysisCode with a user requirement plus a pre-rendered plan_status string and
# checks that the returned code is non-empty and mentions sklearn.
@pytest.mark.asyncio
async def test_write_code_with_plan():
write_code = WriteAnalysisCode()
user_requirement = "Run data analysis on sklearn Iris dataset, include a plot"
# plan_status text: empty "Finished Tasks" sections, the current task, and task guidance.
plan_status = "\n## Finished Tasks\n### code\n```python\n\n```\n\n### execution result\n\n\n## Current Task\nLoad the sklearn Iris dataset and perform exploratory data analysis\n\n## Task Guidance\nWrite complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.\nSpecifically, \nThe current task is about exploratory data analysis, please note the following:\n- Distinguish column types with `select_dtypes` for tailored analysis and visualization, such as correlation.\n- Remember to `import numpy as np` before using Numpy functions.\n\n"
code = await write_code.run(user_requirement=user_requirement, plan_status=plan_status)
assert len(code) > 0
assert "sklearn" in code
# Calls WriteAnalysisCode with a tool_info prompt describing the FillMissingValue tool and
# checks that the generated code references the "metagpt.tools.libs" package.
@pytest.mark.asyncio
async def test_write_code_with_tools():
write_code = WriteAnalysisCode()
user_requirement = "Preprocess sklearn Wine recognition dataset and train a model to predict wine class (20% as validation), and show validation accuracy."
# NOTE(review): the tool description dict inside this prompt appears to be missing its final
# closing brace; it is only used as prompt text here.
tool_info = """
## Capabilities
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python class or function.
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
## Available Tools:
Each tool is described in JSON format. When you call a tool, import the tool from its path first.
{'FillMissingValue': {'type': 'class', 'description': 'Completing missing values with simple strategies.', 'methods': {'__init__': {'type': 'function', 'description': 'Initialize self. ', 'signature': '(self, features: \'list\', strategy: "Literal[\'mean\', \'median\', \'most_frequent\', \'constant\']" = \'mean\', fill_value=None)', 'parameters': 'Args: features (list): Columns to be processed. strategy (Literal["mean", "median", "most_frequent", "constant"], optional): The imputation strategy, notice \'mean\' and \'median\' can only be used for numeric features. Defaults to \'mean\'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.'}, 'fit': {'type': 'function', 'description': 'Fit a model to be used in subsequent transform. ', 'signature': "(self, df: 'pd.DataFrame')", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame.'}, 'fit_transform': {'type': 'function', 'description': 'Fit and transform the input DataFrame. ', 'signature': "(self, df: 'pd.DataFrame') -> 'pd.DataFrame'", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.'}, 'transform': {'type': 'function', 'description': 'Transform the input DataFrame with the fitted model. ', 'signature': "(self, df: 'pd.DataFrame') -> 'pd.DataFrame'", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.'}}, 'tool_path': 'metagpt/tools/libs/data_preprocess.py'}
"""
code = await write_code.run(user_requirement=user_requirement, tool_info=tool_info)
assert len(code) > 0
assert "metagpt.tools.libs" in code
# Supplies a failed attempt (wrong code + traceback) as working memory and asks
# WriteAnalysisCode, with use_reflection=True, for new code; the result must use read_csv.
@pytest.mark.asyncio
async def test_debug_with_reflection():
user_requirement = "read a dataset test.csv and print its head"
# Hand-written plan status: nothing finished yet, the current task is loading test.csv.
plan_status = """
## Finished Tasks
### code
```python
```
### execution result
## Current Task
import pandas and load the dataset from 'test.csv'.
## Task Guidance
Write complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.
Specifically,
"""
wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv
# Canned traceback text for the wrong code above; nothing is executed in this test.
error = """
Traceback (most recent call last):
File "<stdin>", line 2, in <module>
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 478, in read_excel
io = ExcelFile(io, storage_options=storage_options, engine=engine)
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 1500, in __init__
raise ValueError(
ValueError: Excel file format cannot be determined, you must specify an engine manually.
"""
# Working memory order: the assistant's wrong code, then the user-reported error.
working_memory = [
Message(content=wrong_code, role="assistant"),
Message(content=error, role="user"),
]
new_code = await WriteAnalysisCode().run(
user_requirement=user_requirement,
plan_status=plan_status,
working_memory=working_memory,
use_reflection=True,
)
assert "read_csv" in new_code # should correct read_excel to read_csv

View file

@ -1,6 +1,6 @@
import pytest
from metagpt.actions.ci.write_plan import (
from metagpt.actions.di.write_plan import (
Plan,
Task,
WritePlan,
@ -23,12 +23,10 @@ def test_precheck_update_plan_from_rsp():
@pytest.mark.asyncio
@pytest.mark.parametrize("use_tools", [(False), (True)])
async def test_write_plan(use_tools):
async def test_write_plan():
rsp = await WritePlan().run(
context=[Message("run analysis on sklearn iris dataset", role="user")], use_tools=use_tools
context=[Message("Run data analysis on sklearn Iris dataset, include a plot", role="user")]
)
assert "task_id" in rsp
assert "instruction" in rsp
assert "json" not in rsp # the output should be the content inside ```json ```

View file

@ -37,7 +37,7 @@ DESIGN = {
}
TASKS = {
TASK = {
"Required Python packages": ["pygame==2.0.1"],
"Required Other language third-party packages": ["No third-party dependencies required"],
"Logic Analysis": [

View file

@ -9,8 +9,10 @@
import pytest
from metagpt.actions.design_api import WriteDesign
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.schema import Message
from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON
@pytest.mark.asyncio
@ -25,3 +27,16 @@ async def test_design_api(context):
logger.info(result)
assert result
@pytest.mark.asyncio
async def test_refined_design_api(context):
    """Run WriteDesign on a repo that already stores a refined PRD and a sample design.

    Both documents are saved as ``1.txt`` before the action is run with an
    empty message; the test only requires that the action returns a truthy
    result.
    """
    # Seed the repo with the documents the action is expected to pick up.
    await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON))
    await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE)

    action = WriteDesign(context=context, llm=LLM())
    empty_message = Message(content="", instruct_content=None)
    outcome = await action.run(empty_message)

    logger.info(outcome)
    assert outcome

View file

@ -9,13 +9,19 @@
import pytest
from metagpt.actions.project_management import WriteTasks
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.schema import Message
from tests.data.incremental_dev_project.mock import (
REFINED_DESIGN_JSON,
REFINED_PRD_JSON,
TASK_SAMPLE,
)
from tests.metagpt.actions.mock_json import DESIGN, PRD
@pytest.mark.asyncio
async def test_design_api(context):
async def test_task(context):
await context.repo.docs.prd.save("1.txt", content=str(PRD))
await context.repo.docs.system_design.save("1.txt", content=str(DESIGN))
logger.info(context.git_repo)
@ -26,3 +32,19 @@ async def test_design_api(context):
logger.info(result)
assert result
@pytest.mark.asyncio
async def test_refined_task(context):
    """Run WriteTasks on a repo that already stores a PRD, a design and a task document.

    All three documents are saved as ``2.txt``; the action is then run with an
    empty message and must return a truthy result.
    """
    # Seed the repo: refined PRD, refined design and the sample task list.
    await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON))
    await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON))
    await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE)
    logger.info(context.git_repo)

    write_tasks = WriteTasks(context=context, llm=LLM())
    empty_message = Message(content="", instruct_content=None)
    outcome = await write_tasks.run(empty_message)

    logger.info(outcome)
    assert outcome

View file

@ -10,13 +10,14 @@ from openai._models import BaseModel
from metagpt.actions.action_node import ActionNode, dict_to_markdown
from metagpt.actions.project_management import NEW_REQ_TEMPLATE
from metagpt.actions.project_management_an import REFINED_PM_NODE
from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE
from metagpt.llm import LLM
from tests.data.incremental_dev_project.mock import (
REFINED_DESIGN_JSON,
REFINED_TASKS_JSON,
TASKS_SAMPLE,
REFINED_TASK_JSON,
TASK_SAMPLE,
)
from tests.metagpt.actions.mock_json import TASK
@pytest.fixture()
@ -24,20 +25,40 @@ def llm():
return LLM()
def mock_refined_tasks_json():
    """Return the ``REFINED_TASKS_JSON`` sample; assigned as a stand-in for ``model_dump`` in the tests below."""
    return REFINED_TASKS_JSON
def mock_refined_task_json():
    """Return the ``REFINED_TASK_JSON`` sample; assigned as a stand-in for ``model_dump`` in the tests below."""
    return REFINED_TASK_JSON
def mock_task_json():
    """Return the ``TASK`` sample; assigned as a stand-in for ``model_dump`` in the tests below."""
    return TASK
@pytest.mark.asyncio
async def test_project_management_an(mocker):
    """Check that the (patched) PM_NODE.fill result exposes the expected task keys.

    ``PM_NODE.fill`` is replaced by a mock returning a hand-built node whose
    ``model_dump`` yields the ``TASK`` sample, so no LLM call takes place.
    """
    placeholder_child = ActionNode(key="", expected_type=str, instruction="", example="")
    root = ActionNode.from_children("ProjectManagement", [placeholder_child])
    root.instruct_content = BaseModel()
    root.instruct_content.model_dump = mock_task_json

    mocker.patch("metagpt.actions.project_management_an.PM_NODE.fill", return_value=root)

    # NOTE(review): this test does not request the `llm` fixture, so `llm` here
    # is the module-level fixture function itself; it is only passed through to
    # the patched `fill`, which ignores its arguments.
    node = await PM_NODE.fill(dict_to_markdown(REFINED_DESIGN_JSON), llm)

    for expected_key in ("Logic Analysis", "Task list", "Shared Knowledge"):
        assert expected_key in node.instruct_content.model_dump()
@pytest.mark.asyncio
async def test_project_management_an_inc(mocker):
root = ActionNode.from_children(
"RefinedProjectManagement", [ActionNode(key="", expected_type=str, instruction="", example="")]
)
root.instruct_content = BaseModel()
root.instruct_content.model_dump = mock_refined_tasks_json
root.instruct_content.model_dump = mock_refined_task_json
mocker.patch("metagpt.actions.project_management_an.REFINED_PM_NODE.fill", return_value=root)
prompt = NEW_REQ_TEMPLATE.format(old_task=TASKS_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON))
prompt = NEW_REQ_TEMPLATE.format(old_task=TASK_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON))
node = await REFINED_PM_NODE.fill(prompt, llm)
assert "Refined Logic Analysis" in node.instruct_content.model_dump()

View file

@ -23,6 +23,8 @@ async def test_rebuild(context):
context=context,
)
await action.run()
rows = await action.graph_db.select()
assert rows
assert context.repo.docs.graph_repo.changed_files
@ -45,6 +47,12 @@ def test_align_path(path, direction, diff, want):
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT/metagpt", "=", "."),
("/Users/x/github/MetaGPT", "/Users/x/github/MetaGPT/metagpt", "-", "metagpt"),
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT", "+", "metagpt"),
(
"/Users/x/github/MetaGPT-env/lib/python3.9/site-packages/moviepy",
"/Users/x/github/MetaGPT-env/lib/python3.9/site-packages/",
"+",
"moviepy",
),
],
)
def test_diff_path(path_root, package_root, want_direction, want_diff):

View file

@ -4,6 +4,7 @@
@Time : 2024/1/4
@Author : mashenquan
@File : test_rebuild_sequence_view.py
@Desc : Unit tests for reconstructing the sequence diagram from a source code project.
"""
from pathlib import Path
@ -14,25 +15,40 @@ from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.llm import LLM
from metagpt.utils.common import aread
from metagpt.utils.git_repository import ChangeType
from metagpt.utils.graph_repository import SPO
@pytest.mark.asyncio
@pytest.mark.skip
async def test_rebuild(context):
async def test_rebuild(context, mocker):
# Mock
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.class_view.json")
graph_db_filename = Path(context.repo.workdir.name).with_suffix(".json")
await context.repo.docs.graph_repo.save(filename=str(graph_db_filename), content=data)
context.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
context.git_repo.commit("commit1")
# mock_spo = SPO(
# subject="metagpt/startup.py:__name__:__main__",
# predicate="has_page_info",
# object_='{"lineno":78,"end_lineno":79,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
# )
mock_spo = SPO(
subject="metagpt/management/skill_manager.py:__name__:__main__",
predicate="has_page_info",
object_='{"lineno":113,"end_lineno":116,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
)
mocker.patch.object(RebuildSequenceView, "_search_main_entry", return_value=[mock_spo])
action = RebuildSequenceView(
name="RedBean",
i_context=str(Path(__file__).parent.parent.parent.parent / "metagpt"),
i_context=str(
Path(__file__).parent.parent.parent.parent / "metagpt/management/skill_manager.py:__name__:__main__"
),
llm=LLM(),
context=context,
)
await action.run()
rows = await action.graph_db.select()
assert rows
assert context.repo.docs.graph_repo.changed_files

View file

@ -6,14 +6,13 @@
@File : test_summarize_code.py
@Modified By: mashenquan, 2023-12-6. Unit test for summarize_code.py
"""
import uuid
from pathlib import Path
import pytest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.logs import logger
from metagpt.schema import CodeSummarizeContext
from tests.mock.mock_llm import MockLLM
DESIGN_CONTENT = """
{"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. 
For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are 
already clear."}
@ -175,12 +174,87 @@ class Snake:
"""
mock_rsp = """
```mermaid
classDiagram
class Game{
+int score
+int level
+Snake snake
+Food food
+start_game() void
+initialize_game() void
+game_loop() void
+update() void
+draw() void
+handle_events() void
+check_collision() void
+increase_score() void
+increase_level() void
+game_over() void
Game()
}
class Snake{
+list body
+tuple direction
+move() void
+change_direction(direction: str) void
+grow() void
+get_head() tuple
+get_body() list
Snake()
}
class Food{
+tuple position
+generate() void
+get_position() tuple
Food()
}
Game "1" -- "1" Snake: has
Game "1" -- "1" Food: has
```
```sequenceDiagram
participant M as Main
participant G as Game
participant S as Snake
participant F as Food
M->>G: start_game()
G->>G: initialize_game()
G->>G: game_loop()
G->>S: move()
G->>S: change_direction()
G->>S: grow()
G->>F: generate()
S->>S: move()
S->>S: change_direction()
S->>S: grow()
F->>F: generate()
```
## Summary
The code consists of the main game logic, including the Game, Snake, and Food classes. The game loop is responsible for updating and drawing the game elements, handling events, checking collisions, and managing the game state. The Snake class handles the movement, growth, and direction changes of the snake, while the Food class is responsible for generating and tracking the position of food items.
## TODOs
- Modify 'game.py' to add the implementation of obstacle handling and interaction with the game loop.
- Implement 'obstacle.py' to include the methods for spawning, moving, and disappearing of obstacles, as well as collision detection with the snake.
- Update 'main.py' to initialize the obstacle and incorporate it into the game loop.
- Update the mermaid call flow diagram to include the interaction with the obstacle.
```python
{
"files_to_modify": {
"game.py": "Add obstacle handling and interaction with the game loop",
"obstacle.py": "Implement obstacle class with necessary methods",
"main.py": "Initialize the obstacle and incorporate it into the game loop"
}
}
```
"""
@pytest.mark.asyncio
async def test_summarize_code(context):
git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"
git_dir.mkdir(parents=True, exist_ok=True)
async def test_summarize_code(context, mocker):
context.src_workspace = context.git_repo.workdir / "src"
await context.repo.docs.system_design.save(filename="1.json", content=DESIGN_CONTENT)
await context.repo.docs.task.save(filename="1.json", content=TASK_CONTENT)
@ -189,6 +263,7 @@ async def test_summarize_code(context):
await context.repo.srcs.save(filename="game.py", content=GAME_PY)
await context.repo.srcs.save(filename="main.py", content=MAIN_PY)
await context.repo.srcs.save(filename="snake.py", content=SNAKE_PY)
mocker.patch.object(MockLLM, "_mock_rsp", return_value=mock_rsp)
all_files = context.repo.srcs.all_files
summarization_context = CodeSummarizeContext(

View file

@ -6,7 +6,7 @@
@File : test_write_code.py
@Modified By: mashenquan, 2023-12-6. According to RFC 135
"""
import json
from pathlib import Path
import pytest
@ -14,10 +14,27 @@ import pytest
from metagpt.actions.write_code import WriteCode
from metagpt.logs import logger
from metagpt.schema import CodingContext, Document
from metagpt.utils.common import aread
from metagpt.utils.common import CodeParser, aread
from tests.data.incremental_dev_project.mock import (
CODE_PLAN_AND_CHANGE_SAMPLE,
REFINED_CODE_INPUT_SAMPLE,
REFINED_DESIGN_JSON,
REFINED_TASK_JSON,
)
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
def setup_inc_workdir(context, inc: bool = False):
    """Point ``context`` at the test source workspace and return it.

    The source workspace is always set to ``<git workdir>/src``. When ``inc``
    is truthy the context is additionally switched to incremental mode:
    ``config.inc`` is set, the legacy workspace becomes ``<git workdir>/old``
    and ``config.project_path`` is set to ``"old"``.

    Args:
        context: Object exposing ``git_repo.workdir``, ``repo`` and ``config``.
        inc: Whether to configure the incremental-development layout.

    Returns:
        The same ``context`` object, mutated in place.
    """
    context.src_workspace = context.git_repo.workdir / "src"
    if not inc:
        return context

    # Incremental mode: legacy code lives next to the new sources, under "old".
    context.config.inc = inc
    legacy_workspace = context.repo.git_repo.workdir / "old"
    context.repo.old_workspace = legacy_workspace
    context.config.project_path = "old"
    return context
@pytest.mark.asyncio
async def test_write_code(context):
# Prerequisites
@ -81,5 +98,66 @@ async def test_write_code_deps(context):
assert rsp.code_doc.content
@pytest.mark.asyncio
async def test_write_refined_code(context, git_dir):
    """WriteCode in incremental mode must return a response carrying non-empty code.

    The repo is seeded with a refined design, a refined task list and a code
    plan (all as ``1.json``), plus a legacy ``game.py`` in the old workspace.
    """
    # Prerequisites
    context = setup_inc_workdir(context, inc=True)

    design_json = json.dumps(REFINED_DESIGN_JSON)
    await context.repo.docs.system_design.save(filename="1.json", content=design_json)
    task_json = json.dumps(REFINED_TASK_JSON)
    await context.repo.docs.task.save(filename="1.json", content=task_json)
    plan_json = json.dumps(CODE_PLAN_AND_CHANGE_SAMPLE)
    await context.repo.docs.code_plan_and_change.save(filename="1.json", content=plan_json)

    # old_workspace contains the legacy code
    legacy_code = CodeParser.parse_code(block="", text=REFINED_CODE_INPUT_SAMPLE)
    await context.repo.with_src_path(context.repo.old_workspace).srcs.save(filename="game.py", content=legacy_code)

    # Load the stored documents back, in the order the coding context lists them.
    design_doc = await context.repo.docs.system_design.get(filename="1.json")
    task_doc = await context.repo.docs.task.get(filename="1.json")
    plan_doc = await context.repo.docs.code_plan_and_change.get(filename="1.json")
    coding_ctx = CodingContext(
        filename="game.py",
        design_doc=design_doc,
        task_doc=task_doc,
        code_plan_and_change_doc=plan_doc,
        code_doc=Document(filename="game.py", content="", root_path="src"),
    )
    coding_doc = Document(root_path="src", filename="game.py", content=coding_ctx.json())

    write_code = WriteCode(i_context=coding_doc, context=context)
    rsp = await write_code.run()

    assert rsp
    assert rsp.code_doc.content
@pytest.mark.asyncio
async def test_get_codes(context):
    """WriteCode.get_codes must return non-empty text with and without ``use_inc``.

    ``game.py`` and ``ui.py`` are written to both the new and the legacy
    workspace; ``gui.py`` and ``main.py`` exist only in the legacy one.
    ``gui.py`` is the file passed as ``exclude``.
    """
    # Prerequisites
    context = setup_inc_workdir(context, inc=True)

    # Files present in both workspaces: new version first, then the legacy one.
    for shared_name in ("game.py", "ui.py"):
        new_body = f"# {shared_name}\nnew code ..."
        await context.repo.with_src_path(context.src_workspace).srcs.save(filename=shared_name, content=new_body)
        legacy_body = f"# {shared_name}\nlegacy code ..."
        await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
            filename=shared_name, content=legacy_body
        )

    # Files that exist only in the legacy workspace.
    await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
        filename="gui.py", content="# gui.py\nlegacy code ..."
    )
    await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
        filename="main.py", content='# main.py\nif __name__ == "__main__":\n main()'
    )

    task_doc = Document(filename="1.json", content=json.dumps(REFINED_TASK_JSON))
    context.repo = context.repo.with_src_path(context.src_workspace)

    # Ready to write gui.py
    plain_codes = await WriteCode.get_codes(task_doc=task_doc, exclude="gui.py", project_repo=context.repo)
    inc_codes = await WriteCode.get_codes(
        task_doc=task_doc, exclude="gui.py", project_repo=context.repo, use_inc=True
    )

    logger.info(plain_codes)
    logger.info(inc_codes)
    assert plain_codes
    assert inc_codes
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -5,6 +5,8 @@
@Author : mannaandpoem
@File : test_write_code_plan_and_change_an.py
"""
import json
import pytest
from openai._models import BaseModel
@ -14,15 +16,21 @@ from metagpt.actions.write_code_plan_and_change_an import (
REFINED_TEMPLATE,
WriteCodePlanAndChange,
)
from metagpt.logs import logger
from metagpt.schema import CodePlanAndChangeContext
from metagpt.utils.common import CodeParser
from tests.data.incremental_dev_project.mock import (
CODE_PLAN_AND_CHANGE_SAMPLE,
DESIGN_SAMPLE,
NEW_REQUIREMENT_SAMPLE,
REFINED_CODE_INPUT_SAMPLE,
REFINED_CODE_SAMPLE,
TASKS_SAMPLE,
REFINED_DESIGN_JSON,
REFINED_PRD_JSON,
REFINED_TASK_JSON,
TASK_SAMPLE,
)
from tests.metagpt.actions.test_write_code import setup_inc_workdir
def mock_code_plan_and_change():
@ -30,27 +38,35 @@ def mock_code_plan_and_change():
@pytest.mark.asyncio
async def test_write_code_plan_and_change_an(mocker):
async def test_write_code_plan_and_change_an(mocker, context, git_dir):
context = setup_inc_workdir(context, inc=True)
await context.repo.docs.prd.save(filename="2.json", content=json.dumps(REFINED_PRD_JSON))
await context.repo.docs.system_design.save(filename="2.json", content=json.dumps(REFINED_DESIGN_JSON))
await context.repo.docs.task.save(filename="2.json", content=json.dumps(REFINED_TASK_JSON))
await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
filename="game.py", content=CodeParser.parse_code(block="", text=REFINED_CODE_INPUT_SAMPLE)
)
root = ActionNode.from_children(
"WriteCodePlanAndChange", [ActionNode(key="", expected_type=str, instruction="", example="")]
)
root.instruct_content = BaseModel()
root.instruct_content.model_dump = mock_code_plan_and_change
mocker.patch("metagpt.actions.write_code_plan_and_change_an.WriteCodePlanAndChange.run", return_value=root)
requirement = "New requirement"
prd_filename = "prd.md"
design_filename = "design.md"
task_filename = "task.md"
code_plan_and_change_context = CodePlanAndChangeContext(
requirement=requirement,
prd_filename=prd_filename,
design_filename=design_filename,
task_filename=task_filename,
mocker.patch(
"metagpt.actions.write_code_plan_and_change_an.WRITE_CODE_PLAN_AND_CHANGE_NODE.fill", return_value=root
)
node = await WriteCodePlanAndChange(i_context=code_plan_and_change_context).run()
assert "Code Plan And Change" in node.instruct_content.model_dump()
code_plan_and_change_context = CodePlanAndChangeContext(
requirement="New requirement",
prd_filename="2.json",
design_filename="2.json",
task_filename="2.json",
)
node = await WriteCodePlanAndChange(i_context=code_plan_and_change_context, context=context).run()
assert "Development Plan" in node.instruct_content.model_dump()
assert "Incremental Change" in node.instruct_content.model_dump()
@pytest.mark.asyncio
@ -60,7 +76,7 @@ async def test_refine_code(mocker):
user_requirement=NEW_REQUIREMENT_SAMPLE,
code_plan_and_change=CODE_PLAN_AND_CHANGE_SAMPLE,
design=DESIGN_SAMPLE,
task=TASKS_SAMPLE,
task=TASK_SAMPLE,
code=REFINED_CODE_INPUT_SAMPLE,
logs="",
feedback="",
@ -69,3 +85,25 @@ async def test_refine_code(mocker):
)
code = await WriteCode().write_code(prompt=prompt)
assert "def" in code
@pytest.mark.asyncio
async def test_get_old_code(context, git_dir):
    """WriteCodePlanAndChange.get_old_codes must return the legacy source text.

    A sample ``game.py`` is stored in the legacy workspace; the returned text
    is expected to contain both ``def`` and ``class``.
    """
    context = setup_inc_workdir(context, inc=True)

    # Put the legacy code where the action looks for it.
    legacy_srcs = context.repo.with_src_path(context.repo.old_workspace).srcs
    await legacy_srcs.save(filename="game.py", content=REFINED_CODE_INPUT_SAMPLE)

    plan_context = CodePlanAndChangeContext(
        requirement="New requirement",
        prd_filename="1.json",
        design_filename="1.json",
        task_filename="1.json",
    )
    action = WriteCodePlanAndChange(context=context, i_context=plan_context)

    old_codes = await action.get_old_codes()
    logger.info(old_codes)

    for keyword in ("def", "class"):
        assert keyword in old_codes

View file

@ -32,5 +32,28 @@ def add(a, b):
print(f"输出内容: {captured.out}")
@pytest.mark.asyncio
async def test_write_code_review_inc(capfd, context):
context.src_workspace = context.repo.workdir / "srcs"
context.config.inc = True
code = """
def add(a, b):
return a +
"""
code_plan_and_change = """
def add(a, b):
- return a +
+ return a + b
"""
coding_context = CodingContext(
filename="math.py",
design_doc=Document(content="编写一个从a加b的函数返回a+b"),
code_doc=Document(content=code),
code_plan_and_change_doc=Document(content=code_plan_and_change),
)
await WriteCodeReview(i_context=coding_context, context=context).run()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -6,6 +6,7 @@
@File : test_write_prd.py
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`.
"""
import pytest
from metagpt.actions import UserRequirement, WritePRD
@ -15,6 +16,8 @@ from metagpt.roles.product_manager import ProductManager
from metagpt.roles.role import RoleReactMode
from metagpt.schema import Message
from metagpt.utils.common import any_to_str
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE
from tests.metagpt.actions.test_write_code import setup_inc_workdir
@pytest.mark.asyncio
@ -34,5 +37,41 @@ async def test_write_prd(new_filename, context):
assert product_manager.context.repo.docs.prd.changed_files
@pytest.mark.asyncio
async def test_write_prd_inc(new_filename, context, git_dir):
    """WritePRD in incremental mode must return a PRD mentioning "Refined Requirements".

    An existing PRD (``1.txt``) and the new requirement file are stored in the
    repo before the action is run with the new requirement as message content.
    """
    context = setup_inc_workdir(context, inc=True)
    await context.repo.docs.prd.save("1.txt", PRD_SAMPLE)
    await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)

    write_prd = WritePRD(context=context)
    requirement_msg = Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)
    prd = await write_prd.run(requirement_msg)

    logger.info(NEW_REQUIREMENT_SAMPLE)
    logger.info(prd)

    # The PRD must exist, be non-empty and carry the refined-requirements section.
    assert prd is not None
    assert prd.content != ""
    assert "Refined Requirements" in prd.content
@pytest.mark.asyncio
async def test_fix_debug(new_filename, context, git_dir):
    """WritePRD must still return non-empty content for a bug-fix style requirement.

    The source workspace is a directory named after the git workdir itself and
    holds a minimal ``main.py``; the requirement asks to fix a bug.
    """
    workdir = context.git_repo.workdir
    context.src_workspace = workdir / context.git_repo.workdir.name
    await context.repo.with_src_path(context.src_workspace).srcs.save(
        filename="main.py", content='if __name__ == "__main__":\nmain()'
    )

    requirements = "Please fix the bug in the code."
    await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)

    write_prd = WritePRD(context=context)
    bugfix_msg = Message(content=requirements, instruct_content=None)
    prd = await write_prd.run(bugfix_msg)
    logger.info(prd)

    # The PRD must exist and be non-empty.
    assert prd is not None
    assert prd.content != ""
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -12,7 +12,7 @@ from metagpt.document_store.chromadb_store import ChromaStore
def test_chroma_store():
"""FIXMEchroma使用感觉很诡异一用Python就挂测试用例里也是"""
# 创建 ChromaStore 实例,使用 'sample_collection' 集合
document_store = ChromaStore("sample_collection_1")
document_store = ChromaStore("sample_collection_1", get_or_create=True)
# 使用 write 方法添加多个文档
document_store.write(

View file

@ -6,6 +6,7 @@
@File : test_faiss_store.py
"""
import numpy as np
import pytest
from metagpt.const import EXAMPLE_PATH
@ -14,9 +15,24 @@ from metagpt.logs import logger
from metagpt.roles import Sales
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
    """Return one random, standardized 1536-d vector per input text.

    Stand-in for the OpenAI batch-embedding call so tests need no network or
    API key. ``self`` and ``show_progress`` are accepted only to match the
    patched method's signature and are ignored.

    Args:
        texts: Texts to "embed"; only their count matters.
        show_progress: Ignored.

    Returns:
        ``len(texts)`` lists of 1536 finite floats (an empty list for no texts).
    """
    num = len(texts)
    embeds = np.random.randint(1, 100, size=(num, 1536)).astype(float)  # 1536: openai embedding dim

    # Standardize every vector over its own components (axis=1). Standardizing
    # across the batch (axis=0) divides by a zero std whenever a single text is
    # embedded -- or whenever a column happens to be constant -- which filled
    # the result with NaN.
    mean = embeds.mean(axis=1, keepdims=True)
    std = embeds.std(axis=1, keepdims=True)
    std[std == 0] = 1.0  # a constant vector stays all-zero instead of becoming NaN
    return ((embeds - mean) / std).tolist()
def mock_openai_embed_document(self, text: str) -> list[float]:
    """Embed a single text by delegating to the batch mock and unwrapping its only vector."""
    return mock_openai_embed_documents(self, [text])[0]
@pytest.mark.asyncio
async def test_search_json():
store = FaissStore(EXAMPLE_PATH / "example.json")
async def test_search_json(mocker):
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.json")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
@ -24,8 +40,11 @@ async def test_search_json():
@pytest.mark.asyncio
async def test_search_xlsx():
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
async def test_search_xlsx(mocker):
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
@ -33,8 +52,11 @@ async def test_search_xlsx():
@pytest.mark.asyncio
async def test_write():
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
async def test_write(mocker):
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
_faiss_store = store.write()
assert _faiss_store.docstore
assert _faiss_store.index
assert _faiss_store.storage_context.docstore
assert _faiss_store.storage_context.vector_store.client

View file

@ -1,14 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of MincraftExtEnv
from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
def test_mincraft_ext_env():
    """Smoke-test MincraftExtEnv: server is set, the skill checkpoint dir exists, warm-up has 7 optional items."""
    env = MincraftExtEnv()
    assert env.server, f"{env.server_host}:{env.server_port}"

    skill_code_dir = MC_CKPT_DIR.joinpath("skill/code")
    assert skill_code_dir.exists()

    optional_items = env.warm_up.get("optional_inventory_items")
    assert optional_items == 7

View file

@ -0,0 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of MinecraftExtEnv
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
def test_minecraft_ext_env():
    """Smoke-test MinecraftExtEnv: server is set, the skill checkpoint dir exists, warm-up has 7 optional items."""
    env = MinecraftExtEnv()
    assert env.server, f"{env.server_host}:{env.server_port}"

    skill_code_dir = MC_CKPT_DIR.joinpath("skill/code")
    assert skill_code_dir.exists()

    optional_items = env.warm_up.get("optional_inventory_items")
    assert optional_items == 7

View file

@ -0,0 +1,42 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : mock OpenAI text embeddings shared by the memory unit tests
import numpy as np
dim = 1536  # openai embedding dim

# Two fixed mock vectors. Each is a nested list of shape 1 x dim, i.e. a
# "batch" holding exactly one embedding.
embed_zeros_arrr = np.zeros(shape=[1, dim]).tolist()
embed_ones_arrr = np.ones(shape=[1, dim]).tolist()

# Canned text -> embedding table used by the mock embedding functions below.
# Entries that share a vector are indistinguishable to the mock; presumably the
# all-zeros entries are meant to count as "similar" to each other and the
# all-ones entries as "different" from them -- confirm against the tests.
text_embed_arr = [
    {"text": "Write a cli snake game", "embed": embed_zeros_arrr},  # mock data, same as below
    {"text": "Write a game of cli snake", "embed": embed_zeros_arrr},
    {"text": "Write a 2048 web game", "embed": embed_ones_arrr},
    {"text": "Write a Battle City", "embed": embed_ones_arrr},
    {
        "text": "The user has requested the creation of a command-line interface (CLI) snake game",
        "embed": embed_zeros_arrr,
    },
    {"text": "The request is command-line interface (CLI) snake game", "embed": embed_zeros_arrr},
    {
        "text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
        "embed": embed_ones_arrr,
    },
]

# Reverse lookup: text -> index of its entry in text_embed_arr.
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
    """Return the canned embedding of every text in ``texts``, in order.

    Stand-in for the OpenAI batch-embedding call. ``self`` and
    ``show_progress`` exist only to match the patched method's signature.

    Previously only ``texts[0]`` was looked up, so a batch of several texts
    got a single vector back; the result now always has ``len(texts)`` rows.

    Args:
        texts: Texts to embed; each must be registered in ``text_embed_arr``.
        show_progress: Ignored.

    Returns:
        One embedding (list of floats) per input text; ``[]`` for no texts.

    Raises:
        KeyError: If a text has no registered mock embedding.
    """
    embeds = []
    for text in texts:
        idx = text_idx_dict.get(text)
        if idx is None:
            # Fail with the offending text instead of an opaque "list indices
            # must be integers" TypeError from indexing with None.
            raise KeyError(f"no mock embedding registered for text: {text!r}")
        # Each table entry stores a 1 x dim batch; take its single vector.
        embeds.append(text_embed_arr[idx]["embed"][0])
    return embeds
def mock_openai_embed_document(self, text: str) -> list[float]:
    """Return the canned embedding of one text by unwrapping the batch mock's first vector."""
    return mock_openai_embed_documents(self, [text])[0]
async def mock_openai_aembed_document(self, text: str) -> list[float]:
    """Async wrapper around the synchronous single-text mock embedding."""
    embedding = mock_openai_embed_document(self, text)
    return embedding

View file

@ -4,20 +4,29 @@
@Desc : unittest of `metagpt/memory/longterm_memory.py`
"""
import os
import pytest
from metagpt.actions import UserRequirement
from metagpt.config2 import config
from metagpt.memory.longterm_memory import LongTermMemory
from metagpt.roles.role import RoleContext
from metagpt.schema import Message
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
from tests.metagpt.memory.mock_text_embed import (
mock_openai_aembed_document,
mock_openai_embed_document,
mock_openai_embed_documents,
text_embed_arr,
)
def test_ltm_search():
@pytest.mark.asyncio
async def test_ltm_search(mocker):
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch(
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
)
role_id = "UTUserLtm(Product Manager)"
from metagpt.environment import Environment
@ -27,41 +36,26 @@ def test_ltm_search():
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)
idea = "Write a cli snake game"
idea = text_embed_arr[0].get("text", "Write a cli snake game")
message = Message(role="User", content=idea, cause_by=UserRequirement)
news = ltm.find_news([message])
news = await ltm.find_news([message])
assert len(news) == 1
ltm.add(message)
sim_idea = "Write a game of cli snake"
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
news = ltm.find_news([sim_message])
news = await ltm.find_news([sim_message])
assert len(news) == 0
ltm.add(sim_message)
new_idea = "Write a 2048 web game"
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
news = ltm.find_news([new_message])
news = await ltm.find_news([new_message])
assert len(news) == 1
ltm.add(new_message)
# restore from local index
ltm_new = LongTermMemory()
ltm_new.recover_memory(role_id, rc)
news = ltm_new.find_news([message])
assert len(news) == 0
ltm_new.recover_memory(role_id, rc)
news = ltm_new.find_news([sim_message])
assert len(news) == 0
new_idea = "Write a Battle City"
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
news = ltm_new.find_news([new_message])
assert len(news) == 1
ltm_new.clear()
ltm.clear()
if __name__ == "__main__":

View file

@ -4,56 +4,75 @@
@Desc : the unittests of metagpt/memory/memory_storage.py
"""
import os
import shutil
from pathlib import Path
from typing import List
import pytest
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.action_node import ActionNode
from metagpt.config2 import config
from metagpt.const import DATA_PATH
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
from tests.metagpt.memory.mock_text_embed import (
mock_openai_aembed_document,
mock_openai_embed_document,
mock_openai_embed_documents,
text_embed_arr,
)
def test_idea_message():
idea = "Write a cli snake game"
@pytest.mark.asyncio
async def test_idea_message(mocker):
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch(
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
)
idea = text_embed_arr[0].get("text", "Write a cli snake game")
role_id = "UTUser1(Product Manager)"
message = Message(role="User", content=idea, cause_by=UserRequirement)
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
assert len(messages) == 0
memory_storage.recover_memory(role_id)
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_idea = "Write a game of cli snake"
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_messages = await memory_storage.search_similar(sim_message)
assert len(new_messages) == 1 # similar, return []
new_idea = "Write a 2048 web game"
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
new_messages = await memory_storage.search_similar(new_message)
assert len(new_messages) == 0
memory_storage.clean()
assert memory_storage.is_initialized is False
def test_actionout_message():
@pytest.mark.asyncio
async def test_actionout_message(mocker):
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch(
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
)
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
ic_obj = ActionNode.create_model_class("prd", out_mapping)
role_id = "UTUser2(Architect)"
content = "The user has requested the creation of a command-line interface (CLI) snake game"
content = text_embed_arr[4].get(
"text", "The user has requested the creation of a command-line interface (CLI) snake game"
)
message = Message(
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
) # WritePRD as test action
@ -61,21 +80,22 @@ def test_actionout_message():
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
assert len(messages) == 0
memory_storage.recover_memory(role_id)
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_conent = "The request is command-line interface (CLI) snake game"
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_messages = await memory_storage.search_similar(sim_message)
assert len(new_messages) == 1 # similar, return []
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
new_conent = text_embed_arr[6].get(
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
)
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
new_messages = await memory_storage.search_similar(new_message)
assert len(new_messages) == 0
memory_storage.clean()
assert memory_storage.is_initialized is False

View file

@ -42,3 +42,21 @@ mock_llm_config_zhipu = LLMConfig(
model="mock_zhipu_model",
proxy="http://localhost:8080",
)
mock_llm_config_spark = LLMConfig(
api_type="spark",
app_id="xxx",
api_key="xxx",
api_secret="xxx",
domain="generalv2",
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
)
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model="qwen-max")
mock_llm_config_anthropic = LLMConfig(
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
)

View file

@ -0,0 +1,185 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : default request & response data for provider unittest
from anthropic.types import (
ContentBlock,
ContentBlockDeltaEvent,
Message,
MessageStartEvent,
TextDelta,
)
from anthropic.types import Usage as AnthropicUsage
from dashscope.api_entities.dashscope_response import (
DashScopeAPIResponse,
GenerationOutput,
GenerationResponse,
GenerationUsage,
)
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from qianfan.resources.typing import QfResponse
from metagpt.provider.base_llm import BaseLLM
prompt = "who are you?"
messages = [{"role": "user", "content": prompt}]
resp_cont_tmpl = "I'm {name}"
default_resp_cont = resp_cont_tmpl.format(name="GPT")
# Partial OpenAI-style ChatCompletion payload: only the `choices` and `usage` fields
def get_part_chat_completion(name: str) -> dict:
    """Build a minimal OpenAI-like chat completion payload as a plain dict.

    Args:
        name: Substituted into ``resp_cont_tmpl`` to form the assistant reply.

    Returns:
        A dict with a single assistant choice (finish reason ``"stop"``) and a
        fixed ``usage`` section.
    """
    assistant_message = {
        "role": "assistant",
        "content": resp_cont_tmpl.format(name=name),
    }
    only_choice = {"index": 0, "message": assistant_message, "finish_reason": "stop"}
    token_usage = {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41}
    return {"choices": [only_choice], "usage": token_usage}
def get_openai_chat_completion(name: str) -> ChatCompletion:
    """Build a canned, non-streaming OpenAI ``ChatCompletion`` for provider tests.

    Args:
        name: Substituted into ``resp_cont_tmpl`` to form the assistant reply.

    Returns:
        A ``ChatCompletion`` holding one assistant choice and fixed token usage.
    """
    reply = ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name))
    only_choice = Choice(finish_reason="stop", index=0, message=reply, logprobs=None)
    token_usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
    return ChatCompletion(
        id="cmpl-a6652c1bb181caae8dd19ad8",
        model="xx/xxx",
        object="chat.completion",
        created=1703300855,
        choices=[only_choice],
        usage=token_usage,
    )
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
    """Build a canned streaming ``ChatCompletionChunk`` for provider tests.

    Args:
        name: Substituted into ``resp_cont_tmpl`` to form the delta content.
        usage_as_dict: When True, the usage is passed to the chunk as the dict
            produced by ``CompletionUsage.model_dump()`` instead of the
            ``CompletionUsage`` instance itself.

    Returns:
        A chunk with a single assistant delta whose finish reason is ``"stop"``.
    """
    usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
    if usage_as_dict:
        usage = usage.model_dump()

    delta = ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name))
    only_choice = AChoice(delta=delta, finish_reason="stop", index=0, logprobs=None)
    return ChatCompletionChunk(
        id="cmpl-a6652c1bb181caae8dd19ad8",
        model="xx/xxx",
        object="chat.completion.chunk",
        created=1703300855,
        choices=[only_choice],
        usage=usage,
    )
# For gemini
gemini_messages = [{"role": "user", "parts": prompt}]
# For QianFan
qf_jsonbody_dict = {
"id": "as-4v1h587fyv",
"object": "chat.completion",
"created": 1695021339,
"result": "",
"is_truncated": False,
"need_clear_history": False,
"usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22},
}
def get_qianfan_response(name: str) -> QfResponse:
    """Build a mocked QianFan response whose ``result`` is the templated reply.

    The body starts from the module-level ``qf_jsonbody_dict`` template, but the
    template is never modified: every call works on its own copy (including the
    nested ``usage`` dict). Previously the template was mutated in place and
    shared by every returned response, so a later call with a different *name*
    silently changed the body of responses created earlier.

    Args:
        name: Substituted into ``resp_cont_tmpl`` to form the ``result`` text.

    Returns:
        A ``QfResponse`` with ``code=200`` and an independent body dict.
    """
    body = {**qf_jsonbody_dict, "usage": dict(qf_jsonbody_dict["usage"])}
    body["result"] = resp_cont_tmpl.format(name=name)
    return QfResponse(code=200, body=body)
# For DashScope
def get_dashscope_response(name: str) -> GenerationResponse:
    """Build a canned DashScope ``GenerationResponse`` for provider tests.

    Args:
        name: Substituted into ``resp_cont_tmpl`` to form the assistant reply.

    Returns:
        The result of ``GenerationResponse.from_api_response`` applied to a
        status-200 API response with one assistant choice and fixed usage.
    """
    assistant_choice = {
        "finish_reason": "stop",
        "message": {"role": "assistant", "content": resp_cont_tmpl.format(name=name)},
    }
    output = GenerationOutput(text="", finish_reason="", choices=[assistant_choice])
    usage = GenerationUsage(input_tokens=12, output_tokens=98, total_tokens=110)
    api_response = DashScopeAPIResponse(status_code=200, output=output, usage=usage)
    return GenerationResponse.from_api_response(api_response)
# For Anthropic
def get_anthropic_response(
    name: str, stream: bool = False
) -> "Message | list[MessageStartEvent | ContentBlockDeltaEvent]":
    """Build canned Anthropic response data for provider tests.

    The return annotation used to claim ``Message`` unconditionally, although
    the streaming branch returns a list of stream events; it now covers both.

    Args:
        name: Used as the ``model`` field and substituted into ``resp_cont_tmpl``
            to form the reply text.
        stream: When True, return the stream events instead of a full message.

    Returns:
        With ``stream=False``: a ``Message`` whose single text block holds the reply.
        With ``stream=True``: a two-element list, a ``MessageStartEvent`` wrapping
        a message with an empty text block followed by a ``ContentBlockDeltaEvent``
        carrying the whole reply as one text delta.
    """
    reply_text = resp_cont_tmpl.format(name=name)

    def _message(text: str) -> Message:
        # Both branches share every field except the text of the content block.
        return Message(
            id="xxx",
            model=name,
            role="assistant",
            type="message",
            content=[ContentBlock(text=text, type="text")],
            usage=AnthropicUsage(input_tokens=10, output_tokens=10),
        )

    if not stream:
        return _message(reply_text)

    return [
        MessageStartEvent(message=_message(""), type="message_start"),
        ContentBlockDeltaEvent(
            index=0,
            delta=TextDelta(text=reply_text, type="text_delta"),
            type="content_block_delta",
        ),
    ]
# For llm general chat functions call
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
    """Check that the common chat entry points of *llm* all return *resp_cont*.

    Args:
        llm: The (mocked) LLM under test.
        prompt: Prompt passed to ``aask``.
        messages: Message list passed to ``acompletion_text``.
        resp_cont: The reply text every call is expected to return.
    """
    # aask: explicitly non-streaming first, then with its default stream setting.
    assert await llm.aask(prompt, stream=False) == resp_cont
    assert await llm.aask(prompt) == resp_cont

    # acompletion_text: non-streaming first, then streaming.
    for use_stream in (False, True):
        assert await llm.acompletion_text(messages, stream=use_stream) == resp_cont

View file

@ -2,31 +2,45 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of Claude2
import pytest
from anthropic.resources.completions import Completion
from metagpt.provider.anthropic_api import Claude2
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from metagpt.provider.anthropic_api import AnthropicLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_anthropic
from tests.metagpt.provider.req_resp_const import (
get_anthropic_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
prompt = "who are you"
resp = "I'am Claude2"
name = "claude-3-opus-20240229"
resp_cont = resp_cont_tmpl.format(name=name)
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
async def mock_anthropic_messages_create(
self, messages: list[dict], model: str, stream: bool = True, max_tokens: int = None, system: str = None
) -> Completion:
if stream:
async def aresp_iterator():
resps = get_anthropic_response(name, stream=True)
for resp in resps:
yield resp
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
def test_claude2_ask(mocker):
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
assert resp == Claude2(mock_llm_config).ask(prompt)
return aresp_iterator()
else:
return get_anthropic_response(name)
@pytest.mark.asyncio
async def test_claude2_aask(mocker):
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
assert resp == await Claude2(mock_llm_config).aask(prompt)
async def test_anthropic_acompletion(mocker):
mocker.patch("anthropic.resources.messages.AsyncMessages.create", mock_anthropic_messages_create)
anthropic_llm = AnthropicLLM(mock_llm_config_anthropic)
resp = await anthropic_llm.acompletion(messages)
assert resp.content[0].text == resp_cont
await llm_general_chat_funcs_test(anthropic_llm, prompt, messages, resp_cont)

View file

@ -11,21 +11,13 @@ import pytest
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
from tests.metagpt.provider.req_resp_const import (
default_resp_cont,
get_part_chat_completion,
prompt,
)
default_chat_resp = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I'am GPT",
},
"finish_reason": "stop",
}
]
}
prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
name = "GPT"
class MockBaseLLM(BaseLLM):
@ -33,16 +25,19 @@ class MockBaseLLM(BaseLLM):
pass
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp
return get_part_chat_completion(name)
async def _achat_completion(self, messages: list[dict], timeout=3):
pass
async def acompletion(self, messages: list[dict], timeout=3):
return default_chat_resp
return get_part_chat_completion(name)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
pass
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
return resp_content
async def close(self):
return default_chat_resp
return default_resp_cont
def test_base_llm():
@ -86,25 +81,25 @@ def test_base_llm():
choice_text = base_llm.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
# resp = base_llm.ask(prompt_msg)
# assert resp == resp_content
# resp = base_llm.ask(prompt)
# assert resp == default_resp_cont
# resp = base_llm.ask_batch([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_batch([prompt])
# assert resp == default_resp_cont
# resp = base_llm.ask_code([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_code([prompt])
# assert resp == default_resp_cont
@pytest.mark.asyncio
async def test_async_base_llm():
base_llm = MockBaseLLM()
resp = await base_llm.aask(prompt_msg)
assert resp == resp_content
resp = await base_llm.aask(prompt)
assert resp == default_resp_cont
resp = await base_llm.aask_batch([prompt_msg])
assert resp == resp_content
resp = await base_llm.aask_batch([prompt])
assert resp == default_resp_cont
# resp = await base_llm.aask_code([prompt_msg])
# assert resp == resp_content
# resp = await base_llm.aask_code([prompt])
# assert resp == default_resp_cont

View file

@ -0,0 +1,73 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of DashScopeLLM
from typing import AsyncGenerator, Union
import pytest
from dashscope.api_entities.dashscope_response import GenerationResponse
from metagpt.provider.dashscope_api import DashScopeLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_dashscope
from tests.metagpt.provider.req_resp_const import (
get_dashscope_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "qwen-max"
resp_cont = resp_cont_tmpl.format(name=name)
@classmethod
def mock_dashscope_call(
    cls,
    messages: list[dict],
    model: str,
    api_key: str,
    result_format: str,
    incremental_output: bool = True,
    stream: bool = False,
) -> GenerationResponse:
    """Synchronous stand-in patched over ``Generation.call``.

    All arguments are accepted and ignored; the canned DashScope response for
    the module-level ``name`` is always returned.
    """
    canned_response = get_dashscope_response(name)
    return canned_response
@classmethod
async def mock_dashscope_acall(
    cls,
    messages: list[dict],
    model: str,
    api_key: str,
    result_format: str,
    incremental_output: bool = True,
    stream: bool = False,
) -> Union[AsyncGenerator[GenerationResponse, None], GenerationResponse]:
    """Asynchronous stand-in patched over ``AGeneration.acall``.

    Returns:
        With ``stream=False`` the canned response itself; with ``stream=True``
        an async generator that yields that single canned response.
    """
    canned = [get_dashscope_response(name)]
    if not stream:
        return canned[0]

    async def _replay(items: list[GenerationResponse]):
        for item in items:
            yield item

    return _replay(canned)
@pytest.mark.asyncio
async def test_dashscope_acompletion(mocker):
    """Exercise DashScopeLLM's sync, async and general chat calls against the mocks."""
    patches = (
        ("dashscope.aigc.generation.Generation.call", mock_dashscope_call),
        ("metagpt.provider.dashscope_api.AGeneration.acall", mock_dashscope_acall),
    )
    for target, replacement in patches:
        mocker.patch(target, replacement)

    llm = DashScopeLLM(mock_llm_config_dashscope)

    sync_resp = llm.completion(messages)
    assert sync_resp.choices[0]["message"]["content"] == resp_cont

    async_resp = await llm.acompletion(messages)
    assert async_resp.choices[0]["message"]["content"] == resp_cont

    await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)

View file

@ -1,114 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of fireworks api
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.fireworks_api import (
MODEL_GRADE_TOKEN_COSTS,
FireworksCostManager,
FireworksLLM,
)
from metagpt.utils.cost_manager import Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="accounts/fireworks/models/llama-v2-13b-chat",
object="chat.completion",
created=1703300855,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=dict(default_resp.usage),
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
def test_fireworks_costmanager():
    """Check model-name to cost-grade lookup and cost accumulation."""
    cost_manager = FireworksCostManager()

    # (model name, expected key into MODEL_GRADE_TOKEN_COSTS)
    expected_grades = [
        ("test", "-1"),
        ("xxx-81b-chat", "-1"),
        ("llama-v2-13b-chat", "16"),
        ("xxx-15.5b-chat", "16"),
        ("xxx-16b-chat", "16"),
        ("xxx-80b-chat", "80"),
        ("mixtral-8x7b-chat", "mixtral-8x7b"),
    ]
    for model_name, grade in expected_grades:
        assert MODEL_GRADE_TOKEN_COSTS[grade] == cost_manager.model_grade_token_costs(model_name)

    cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat")
    assert cost_manager.total_cost == 0.5
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
    """Stand-in patched over ``AsyncCompletions.create``.

    Returns ``default_resp`` when not streaming; otherwise an object whose
    async iteration yields the single ``default_resp_chunk``.
    """
    if not stream:
        return default_resp

    class _SingleChunkStream(object):
        async def __aiter__(self):
            yield default_resp_chunk

    return _SingleChunkStream()
@pytest.mark.asyncio
async def test_fireworks_acompletion(mocker):
    """Exercise FireworksLLM cost tracking and chat calls against the mocked OpenAI client."""
    mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
    fireworks_gpt = FireworksLLM(mock_llm_config)
    fireworks_gpt.model = "llama-v2-13b-chat"

    fireworks_gpt._update_costs(
        usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
    )
    assert fireworks_gpt.get_costs() == Costs(
        total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
    )

    resp = await fireworks_gpt.acompletion(messages)
    # Exact match: the former substring check (`in`) also accepted an empty or
    # truncated reply.
    assert resp.choices[0].message.content == resp_content

    resp = await fireworks_gpt.aask(prompt_msg, stream=False)
    assert resp == resp_content

    resp = await fireworks_gpt.acompletion_text(messages, stream=False)
    assert resp == resp_content

    resp = await fireworks_gpt.acompletion_text(messages, stream=True)
    assert resp == resp_content

    resp = await fireworks_gpt.aask(prompt_msg)
    assert resp == resp_content

View file

@ -11,6 +11,12 @@ from google.generativeai.types import content_types
from metagpt.provider.google_gemini_api import GeminiLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
gemini_messages,
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
@dataclass
@ -18,10 +24,8 @@ class MockGeminiResponse(ABC):
text: str
prompt_msg = "who are you"
messages = [{"role": "user", "parts": prompt_msg}]
resp_content = "I'm gemini from google"
default_resp = MockGeminiResponse(text=resp_content)
resp_cont = resp_cont_tmpl.format(name="gemini")
default_resp = MockGeminiResponse(text=resp_cont)
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
@ -60,28 +64,18 @@ async def test_gemini_acompletion(mocker):
mock_gemini_generate_content_async,
)
gemini_gpt = GeminiLLM(mock_llm_config)
gemini_llm = GeminiLLM(mock_llm_config)
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
usage = gemini_gpt.get_usage(messages, resp_content)
usage = gemini_llm.get_usage(gemini_messages, resp_cont)
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
resp = gemini_gpt.completion(messages)
resp = gemini_llm.completion(gemini_messages)
assert resp == default_resp
resp = await gemini_gpt.acompletion(messages)
resp = await gemini_llm.acompletion(gemini_messages)
assert resp.text == default_resp.text
resp = await gemini_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await gemini_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont)

View file

@ -9,12 +9,15 @@ import pytest
from metagpt.provider.ollama_api import OllamaLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm ollama"
default_resp = {"message": {"role": "assistant", "content": resp_content}}
resp_cont = resp_cont_tmpl.format(name="ollama")
default_resp = {"message": {"role": "assistant", "content": resp_cont}}
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
@ -41,19 +44,12 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
ollama_gpt = OllamaLLM(mock_llm_config)
ollama_llm = OllamaLLM(mock_llm_config)
resp = await ollama_gpt.acompletion(messages)
resp = await ollama_llm.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = await ollama_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await ollama_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await ollama_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await ollama_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await ollama_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont)

View file

@ -1,92 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.utils.cost_manager import Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_content = "I'm llama2"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="llama-v2-13b-chat",
object="chat.completion",
created=1703302755,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
    """Stand-in patched over ``AsyncCompletions.create``.

    Returns ``default_resp`` when not streaming; otherwise an object whose
    async iteration yields the single ``default_resp_chunk``.
    """
    if not stream:
        return default_resp

    class _SingleChunkStream(object):
        async def __aiter__(self):
            yield default_resp_chunk

    return _SingleChunkStream()
@pytest.mark.asyncio
async def test_openllm_acompletion(mocker):
    """Exercise OpenLLM cost tracking and chat calls against the mocked OpenAI client."""
    mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
    openllm_gpt = OpenLLM(mock_llm_config)
    openllm_gpt.model = "llama-v2-13b-chat"

    openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
    assert openllm_gpt.get_costs() == Costs(
        total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
    )

    resp = await openllm_gpt.acompletion(messages)
    # Exact match: the former substring check (`in`) also accepted an empty or
    # truncated reply.
    assert resp.choices[0].message.content == resp_content

    resp = await openllm_gpt.aask(prompt_msg, stream=False)
    assert resp == resp_content

    resp = await openllm_gpt.acompletion_text(messages, stream=False)
    assert resp == resp_content

    resp = await openllm_gpt.acompletion_text(messages, stream=True)
    assert resp == resp_content

    resp = await openllm_gpt.aask(prompt_msg)
    assert resp == resp_content

View file

@ -1,12 +1,11 @@
import json
import pytest
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion import Choice, CompletionUsage
from openai.types.chat.chat_completion_message_tool_call import Function
from PIL import Image
@ -18,6 +17,22 @@ from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_proxy,
)
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "AI assistant"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_openai_chat_completion(name)
default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
@pytest.mark.asyncio
@ -106,9 +121,11 @@ class TestOpenAI:
def test_aask_code_json_decode_error(self, json_decode_error):
instance = OpenAILLM(mock_llm_config)
with pytest.raises(json.decoder.JSONDecodeError) as e:
instance.get_choice_function_arguments(json_decode_error)
assert "JSONDecodeError" in str(e)
code = instance.get_choice_function_arguments(json_decode_error)
assert "code" in code
assert "language" in code
assert "hello world" in code["code"]
logger.info(f'code is : {code["code"]}')
@pytest.mark.asyncio
@ -121,3 +138,29 @@ async def test_gen_image():
images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json")
assert images[0].size == (1024, 1024)
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
    """Stand-in patched over ``AsyncCompletions.create``.

    Returns ``default_resp`` when not streaming; otherwise an object whose
    async iteration yields the single ``default_resp_chunk``.
    """
    if not stream:
        return default_resp

    class _SingleChunkStream(object):
        async def __aiter__(self):
            yield default_resp_chunk

    return _SingleChunkStream()
@pytest.mark.asyncio
async def test_openai_acompletion(mocker):
    """Exercise OpenAILLM's async completion and general chat calls against the mock."""
    mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
    openai_llm = OpenAILLM(mock_llm_config)

    completion = await openai_llm.acompletion(messages)
    first_choice = completion.choices[0]
    assert first_choice.finish_reason == "stop"
    assert first_choice.message.content == resp_cont
    assert completion.usage == usage

    await llm_general_chat_funcs_test(openai_llm, prompt, messages, resp_cont)

View file

@ -0,0 +1,56 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of qianfan api
from typing import AsyncIterator, Union
import pytest
from qianfan.resources.typing import JsonBody, QfResponse
from metagpt.provider.qianfan_api import QianFanLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
from tests.metagpt.provider.req_resp_const import (
get_qianfan_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "ERNIE-Bot-turbo"
resp_cont = resp_cont_tmpl.format(name=name)
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
    """Synchronous stand-in patched over ``ChatCompletion.do``.

    All arguments are accepted and ignored; the canned QianFan response for the
    module-level ``name`` is always returned.
    """
    canned_response = get_qianfan_response(name=name)
    return canned_response
async def mock_qianfan_ado(
    self, messages: list[dict], model: str, stream: bool = True, system: str = None
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
    """Asynchronous stand-in patched over ``ChatCompletion.ado``.

    Returns:
        With ``stream=False`` the canned response itself; with ``stream=True``
        (the default) an async generator that yields that single canned response.
    """
    canned = [get_qianfan_response(name=name)]
    if not stream:
        return canned[0]

    async def _replay():
        for item in canned:
            yield item

    return _replay()
@pytest.mark.asyncio
async def test_qianfan_acompletion(mocker):
    """Exercise QianFanLLM's sync, async and general chat calls against the mocks."""
    patches = (
        ("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do),
        ("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado),
    )
    for target, replacement in patches:
        mocker.patch(target, replacement)

    llm = QianFanLLM(mock_llm_config_qianfan)

    sync_resp = llm.completion(messages)
    assert sync_resp.get("result") == resp_cont

    async_resp = await llm.acompletion(messages)
    assert async_resp.get("result") == resp_cont

    await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)

View file

@ -4,12 +4,18 @@
import pytest
from metagpt.config2 import Config
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
resp_content = "I'm Spark"
resp_cont = resp_cont_tmpl.format(name="Spark")
class MockWebSocketApp(object):
@ -23,7 +29,7 @@ class MockWebSocketApp(object):
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config)
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
ret = get_msg_from_web.run()
@ -31,34 +37,26 @@ def test_get_msg_from_web(mocker):
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_content
return resp_cont
@pytest.mark.asyncio
async def test_spark_aask():
llm = SparkLLM(Config.from_home("spark.yaml").llm)
async def test_spark_aask(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
llm = SparkLLM(mock_llm_config_spark)
resp = await llm.aask("Hello!")
print(resp)
assert resp == resp_cont
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_gpt = SparkLLM(mock_llm_config)
spark_llm = SparkLLM(mock_llm_config)
resp = await spark_gpt.acompletion([])
assert resp == resp_content
resp = await spark_llm.acompletion([])
assert resp == resp_cont
resp = await spark_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=True)
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)

View file

@ -6,22 +6,24 @@ import pytest
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
from tests.metagpt.provider.req_resp_const import (
get_part_chat_completion,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm chatglm-turbo"
default_resp = {
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
}
name = "ChatGLM-4"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_part_chat_completion(name)
async def mock_zhipuai_acreate_stream(**kwargs):
async def mock_zhipuai_acreate_stream(self, **kwargs):
class MockResponse(object):
async def _aread(self):
class Iterator(object):
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}]
async def __aiter__(self):
for event in self.events:
@ -37,7 +39,7 @@ async def mock_zhipuai_acreate_stream(**kwargs):
return MockResponse()
async def mock_zhipuai_acreate(**kwargs) -> dict:
async def mock_zhipuai_acreate(self, **kwargs) -> dict:
return default_resp
@ -46,22 +48,12 @@ async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
resp = await zhipu_gpt.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_content
resp = await zhipu_llm.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_cont
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await zhipu_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
def test_zhipuai_proxy():

View file

@ -0,0 +1,166 @@
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import Document, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
class TestSimpleEngine:
    """Unit tests for ``SimpleEngine`` with its collaborators replaced by mocks."""

    @pytest.fixture
    def mock_simple_directory_reader(self, mocker):
        """Patch ``SimpleDirectoryReader`` as referenced from the engine module."""
        return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")

    @pytest.fixture
    def mock_vector_store_index(self, mocker):
        """Patch ``VectorStoreIndex.from_documents`` as referenced from the engine module."""
        return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")

    @pytest.fixture
    def mock_get_retriever(self, mocker):
        """Patch the ``get_retriever`` factory used by the engine module."""
        return mocker.patch("metagpt.rag.engines.simple.get_retriever")

    @pytest.fixture
    def mock_get_rankers(self, mocker):
        """Patch the ``get_rankers`` factory used by the engine module."""
        return mocker.patch("metagpt.rag.engines.simple.get_rankers")

    @pytest.fixture
    def mock_get_response_synthesizer(self, mocker):
        """Patch the ``get_response_synthesizer`` factory used by the engine module."""
        return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")

    def test_from_docs(
        self,
        mocker,
        mock_simple_directory_reader,
        mock_vector_store_index,
        mock_get_retriever,
        mock_get_rankers,
        mock_get_response_synthesizer,
    ):
        """``from_docs`` reads the inputs and invokes each patched factory exactly once."""
        # Arrange: the patched reader yields two documents, the factories yield mocks.
        loaded_documents = [Document(text="document1"), Document(text="document2")]
        mock_simple_directory_reader.return_value.load_data.return_value = loaded_documents
        mock_get_retriever.return_value = mocker.MagicMock()
        mock_get_rankers.return_value = [mocker.MagicMock()]
        mock_get_response_synthesizer.return_value = mocker.MagicMock()

        # Arrange: arguments handed to the constructor under test.
        input_dir = "test_dir"
        input_files = ["test_file1", "test_file2"]
        transformations = [mocker.MagicMock()]
        embed_model = mocker.MagicMock()
        llm = mocker.MagicMock()
        retriever_configs = [mocker.MagicMock()]
        ranker_configs = [mocker.MagicMock()]
        from_docs_kwargs = {
            "input_dir": input_dir,
            "input_files": input_files,
            "transformations": transformations,
            "embed_model": embed_model,
            "llm": llm,
            "retriever_configs": retriever_configs,
            "ranker_configs": ranker_configs,
        }

        # Act
        engine = SimpleEngine.from_docs(**from_docs_kwargs)

        # Assert
        mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
        mock_vector_store_index.assert_called_once()
        mock_get_retriever.assert_called_once_with(
            configs=retriever_configs, index=mock_vector_store_index.return_value
        )
        mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
        mock_get_response_synthesizer.assert_called_once_with(llm=llm)
        assert isinstance(engine, SimpleEngine)

    @pytest.mark.asyncio
    async def test_asearch(self, mocker):
        """``asearch`` forwards the query to ``aquery`` and returns its result."""
        # Arrange: replace ``aquery`` on the instance with an async stub.
        query_text = "test query"
        stubbed_answer = "expected result"
        aquery_stub = mocker.AsyncMock(return_value=stubbed_answer)
        engine = SimpleEngine(retriever=mocker.MagicMock())
        engine.aquery = aquery_stub

        # Act
        answer = await engine.asearch(query_text)

        # Assert
        aquery_stub.assert_called_once_with(query_text)
        assert answer == stubbed_answer

    @pytest.mark.asyncio
    async def test_aretrieve(self, mocker):
        """``aretrieve`` wraps the query in a ``QueryBundle`` and delegates to the parent class."""
        # Arrange: patch the bundle constructor and the parent's ``aretrieve``.
        bundle_ctor = mocker.patch("metagpt.rag.engines.simple.QueryBundle", return_value="query_bundle")
        parent_aretrieve = mocker.patch(
            "metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock
        )
        parent_aretrieve.return_value = [TextNode(text="node_with_score", metadata={"is_obj": False})]
        engine = SimpleEngine(retriever=mocker.MagicMock())
        query_text = "test query"

        # Act
        nodes = await engine.aretrieve(query_text)

        # Assert
        bundle_ctor.assert_called_once_with(query_text)
        parent_aretrieve.assert_called_once_with("query_bundle")
        assert nodes[0].text == "node_with_score"

    def test_add_docs(self, mocker):
        """``add_docs`` reads the files and passes the transformed nodes to the retriever."""
        # Arrange: patched reader, a modifiable retriever and an index exposing ``_transformations``.
        reader_cls = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
        reader_cls.return_value.load_data.return_value = [
            Document(text="document1"),
            Document(text="document2"),
        ]
        retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
        index = mocker.MagicMock(spec=VectorStoreIndex)
        index._transformations = mocker.MagicMock()
        run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
        run_transformations.return_value = ["node1", "node2"]

        engine = SimpleEngine(retriever=retriever, index=index)
        input_files = ["test_file1", "test_file2"]

        # Act
        engine.add_docs(input_files=input_files)

        # Assert
        reader_cls.assert_called_once_with(input_files=input_files)
        retriever.add_nodes.assert_called_once_with(["node1", "node2"])

    def test_add_objs(self, mocker):
        """``add_objs`` hands the retriever ``TextNode`` items whose metadata carries ``is_obj``."""
        # Arrange
        retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)

        class CustomTextNode(TextNode):
            """TextNode subclass providing the two methods this test stubs out."""

            def rag_key(self):
                return ""

            def model_dump_json(self):
                return ""

        objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
        engine = SimpleEngine(retriever=retriever, index=mocker.MagicMock())

        # Act
        engine.add_objs(objs=objs)

        # Assert: a single add_nodes call whose first positional argument holds the nodes.
        assert retriever.add_nodes.call_count == 1
        added_nodes = retriever.add_nodes.call_args[0][0]
        for node in added_nodes:
            assert isinstance(node, TextNode)
            assert "is_obj" in node.metadata

View file

@ -0,0 +1,102 @@
import pytest
from metagpt.rag.factories.base import ConfigBasedFactory, GenericFactory
class TestGenericFactory:
    """Tests for key-based instance creation in ``GenericFactory``."""

    @pytest.fixture
    def creators(self):
        """Two creators, keyed ``type1`` and ``type2``, that each format a string from ``name``."""
        return {
            "type1": lambda name: f"Instance of type1 with {name}",
            "type2": lambda name: f"Instance of type2 with {name}",
        }

    @pytest.fixture
    def factory(self, creators):
        """A ``GenericFactory`` built from the ``creators`` fixture."""
        return GenericFactory(creators=creators)

    def test_get_instance_success(self, factory):
        """A registered key yields the value produced by its creator."""
        produced = factory.get_instance("type1", name="TestName")
        assert produced == "Instance of type1 with TestName"

    def test_get_instance_failure(self, factory):
        """An unregistered key raises ``ValueError`` naming that key."""
        with pytest.raises(ValueError) as raised:
            factory.get_instance("unknown_key")

        message = str(raised.value)
        assert "Creator not registered for key: unknown_key" in message

    def test_get_instances_success(self, factory):
        """Several registered keys yield one instance each, in key order."""
        produced = factory.get_instances(["type1", "type2"], name="TestName")
        assert produced == ["Instance of type1 with TestName", "Instance of type2 with TestName"]

    @pytest.mark.parametrize(
        "keys,expected_exception_message",
        [
            (["unknown_key"], "Creator not registered for key: unknown_key"),
            (["type1", "unknown_key"], "Creator not registered for key: unknown_key"),
        ],
    )
    def test_get_instances_with_failure(self, factory, keys, expected_exception_message):
        """A key list containing at least one unregistered key raises ``ValueError``."""
        with pytest.raises(ValueError) as raised:
            factory.get_instances(keys, name="TestName")

        message = str(raised.value)
        assert expected_exception_message in message
class DummyConfig:
    """Minimal stand-in config used by the factory tests; it only stores ``name``."""

    def __init__(self, name):
        # Stored as given; the tests also pass ``None`` here.
        self.name = name
class TestConfigBasedFactory:
    """Tests for config-type dispatch and value lookup in ``ConfigBasedFactory``."""

    @pytest.fixture
    def config_creators(self):
        """A single creator keyed by ``DummyConfig`` that reports the config name and ``extra`` kwarg."""
        return {
            DummyConfig: lambda config, **kwargs: f"Processed {config.name} with {kwargs.get('extra', 'no extra')}",
        }

    @pytest.fixture
    def config_factory(self, config_creators):
        """A ``ConfigBasedFactory`` built from the ``config_creators`` fixture."""
        return ConfigBasedFactory(creators=config_creators)

    def test_get_instance_success(self, config_factory):
        """A config of a registered type is passed to its creator together with the kwargs."""
        produced = config_factory.get_instance(DummyConfig(name="TestConfig"), extra="additional data")
        assert produced == "Processed TestConfig with additional data"

    def test_get_instance_failure(self, config_factory):
        """A config of an unregistered type raises ``ValueError``."""

        class UnknownConfig:
            pass

        unregistered = UnknownConfig()
        with pytest.raises(ValueError) as raised:
            config_factory.get_instance(unregistered)

        assert "Unknown config:" in str(raised.value)

    def test_val_from_config_or_kwargs_priority(self):
        """The value found on the config object takes priority over the kwarg of the same name."""
        cfg = DummyConfig(name="ConfigName")
        value = ConfigBasedFactory._val_from_config_or_kwargs("name", cfg, name="KwargsName")
        assert value == "ConfigName"

    def test_val_from_config_or_kwargs_fallback_to_kwargs(self):
        """When the config attribute is ``None`` the kwarg value is used instead."""
        cfg = DummyConfig(name=None)
        value = ConfigBasedFactory._val_from_config_or_kwargs("name", cfg, name="KwargsName")
        assert value == "KwargsName"

    def test_val_from_config_or_kwargs_key_error(self):
        """A key present in neither the config nor the kwargs raises ``KeyError``."""
        cfg = DummyConfig(name=None)
        with pytest.raises(KeyError) as raised:
            ConfigBasedFactory._val_from_config_or_kwargs("missing_key", cfg)

        assert "The key 'missing_key' is required but not provided" in str(raised.value)

View file

@ -0,0 +1,41 @@
import pytest
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from metagpt.rag.factories.ranker import RankerFactory
from metagpt.rag.schema import LLMRankerConfig
class TestRankerFactory:
    """Tests for ``RankerFactory`` ranker creation and LLM extraction."""

    @pytest.fixture
    def ranker_factory(self) -> RankerFactory:
        """A fresh ``RankerFactory`` for each test."""
        return RankerFactory()

    @pytest.fixture
    def mock_llm(self, mocker):
        """A ``MagicMock`` constrained to the ``LLM`` interface."""
        return mocker.MagicMock(spec=LLM)

    def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
        """Without configs, ``get_rankers`` returns an empty collection."""
        mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)

        rankers = ranker_factory.get_rankers()

        assert len(rankers) == 0

    def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
        """One ``LLMRankerConfig`` produces exactly one ``LLMRerank``."""
        llm_config = LLMRankerConfig(llm=mock_llm)

        rankers = ranker_factory.get_rankers(configs=[llm_config])

        assert len(rankers) == 1
        assert isinstance(rankers[0], LLMRerank)

    def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
        """``_create_llm_ranker`` returns an ``LLMRerank``."""
        llm_config = LLMRankerConfig(llm=mock_llm)

        created = ranker_factory._create_llm_ranker(llm_config)

        assert isinstance(created, LLMRerank)

    def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
        """The LLM set on the config is what ``_extract_llm`` returns."""
        llm_config = LLMRankerConfig(llm=mock_llm)

        found = ranker_factory._extract_llm(config=llm_config)

        assert found == mock_llm

    def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
        """The LLM passed as a keyword argument is what ``_extract_llm`` returns."""
        found = ranker_factory._extract_llm(llm=mock_llm)

        assert found == mock_llm

View file

@ -0,0 +1,79 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
class TestRetrieverFactory:
    """Tests for ``RetrieverFactory`` retriever selection and index extraction."""

    @pytest.fixture
    def retriever_factory(self):
        """A fresh ``RetrieverFactory`` for each test."""
        return RetrieverFactory()

    @pytest.fixture
    def mock_faiss_index(self, mocker):
        """A ``MagicMock`` constrained to ``faiss.IndexFlatL2``."""
        return mocker.MagicMock(spec=faiss.IndexFlatL2)

    @pytest.fixture
    def mock_vector_store_index(self, mocker):
        """A mocked ``VectorStoreIndex`` with an embed model and an empty docstore."""
        index = mocker.MagicMock(spec=VectorStoreIndex)
        index._embed_model = mocker.MagicMock()
        index.docstore.docs.values.return_value = []
        return index

    def test_get_retriever_with_faiss_config(
        self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
    ):
        """A single FAISS config yields a ``FAISSRetriever``."""
        faiss_config = FAISSRetrieverConfig(dimensions=128)
        mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
        mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)

        built = retriever_factory.get_retriever(configs=[faiss_config])

        assert isinstance(built, FAISSRetriever)

    def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
        """A single BM25 config yields a ``DynamicBM25Retriever``."""
        bm25_config = BM25RetrieverConfig()
        mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
        mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)

        built = retriever_factory.get_retriever(configs=[bm25_config])

        assert isinstance(built, DynamicBM25Retriever)

    def test_get_retriever_with_multiple_configs_returns_hybrid(
        self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
    ):
        """Two configs together yield a ``SimpleHybridRetriever``."""
        faiss_config = FAISSRetrieverConfig(dimensions=128)
        bm25_config = BM25RetrieverConfig()
        mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
        mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)

        built = retriever_factory.get_retriever(configs=[faiss_config, bm25_config])

        assert isinstance(built, SimpleHybridRetriever)

    def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
        """Without configs the factory returns whatever ``index.as_retriever()`` gives back."""
        mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
        mock_vector_store_index.as_retriever = mocker.MagicMock()

        built = retriever_factory.get_retriever()

        mock_vector_store_index.as_retriever.assert_called_once()
        assert built is mock_vector_store_index.as_retriever.return_value

    def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
        """The index set on the config is what ``_extract_index`` returns."""
        faiss_config = FAISSRetrieverConfig(index=mock_vector_store_index)

        found = retriever_factory._extract_index(config=faiss_config)

        assert found == mock_vector_store_index

    def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
        """The index passed as a keyword argument is what ``_extract_index`` returns."""
        found = retriever_factory._extract_index(index=mock_vector_store_index)

        assert found == mock_vector_store_index

View file

@ -0,0 +1,37 @@
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import Node
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
class TestDynamicBM25Retriever:
    """Tests for ``DynamicBM25Retriever.add_nodes`` with ``BM25Okapi.__init__`` patched out."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        """Build two mocked nodes and a retriever that starts with no nodes."""
        # Create the mocked Document objects.
        first_doc = mocker.MagicMock(spec=Node)
        first_doc.get_content.return_value = "Document content 1"
        second_doc = mocker.MagicMock(spec=Node)
        second_doc.get_content.return_value = "Document content 2"
        self.doc1 = first_doc
        self.doc2 = second_doc
        self.mock_nodes = [first_doc, second_doc]

        # Mock the index.
        index = mocker.MagicMock(spec=VectorStoreIndex)

        # Mock the nodes and tokenizer arguments.
        initial_nodes = []
        tokenizer = mocker.MagicMock()
        self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)

        # Instantiate DynamicBM25Retriever with its required arguments.
        self.retriever = DynamicBM25Retriever(nodes=initial_nodes, tokenizer=tokenizer, index=index)

    def test_add_docs_updates_nodes_and_corpus(self):
        """Adding nodes grows ``_nodes`` and ``_corpus`` and uses the tokenizer and ``BM25Okapi``."""
        # Execute
        self.retriever.add_nodes(self.mock_nodes)

        # Assertions
        expected_count = len(self.mock_nodes)
        assert len(self.retriever._nodes) == expected_count
        assert len(self.retriever._corpus) == expected_count
        self.retriever._tokenizer.assert_called()
        self.mock_bm25okapi.assert_called()

View file

@ -0,0 +1,22 @@
import pytest
from llama_index.core.schema import Node
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
class TestFAISSRetriever:
    """Tests for ``FAISSRetriever.add_nodes`` against a mocked index."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        """Create two mocked nodes and a retriever constructed around a mocked index."""
        # Create the mocked Document objects.
        self.doc1 = mocker.MagicMock(spec=Node)
        self.doc2 = mocker.MagicMock(spec=Node)
        self.mock_nodes = [self.doc1, self.doc2]

        # Mock the index that FAISSRetriever is constructed with.
        self.mock_index = mocker.MagicMock()

        self.retriever = FAISSRetriever(self.mock_index)

    def test_add_docs_calls_insert_for_each_document(self, mocker):
        """``add_nodes`` must result in ``insert_nodes`` being called on the index."""
        self.retriever.add_nodes(self.mock_nodes)

        # Invoke the mock's assertion helper. The previous form,
        # ``assert mock.insert_nodes.assert_called``, only asserted that the
        # bound method object is truthy, so it passed even when nothing was called.
        self.mock_index.insert_nodes.assert_called()

View file

@ -0,0 +1,39 @@
from unittest.mock import AsyncMock
import pytest
from llama_index.core.schema import NodeWithScore, TextNode
from metagpt.rag.retrievers import SimpleHybridRetriever
class TestSimpleHybridRetriever:
    """Tests for result merging in ``SimpleHybridRetriever``."""

    @pytest.mark.asyncio
    async def test_aretrieve(self):
        """Results from two retrievers are merged into one entry per node id."""

        def scored(node_id, score):
            # Shorthand for a scored node with a fixed id.
            return NodeWithScore(node=TextNode(id_=node_id), score=score)

        question = "test query"

        # Two mocked retrievers whose results overlap on node "2".
        first_retriever = AsyncMock()
        first_retriever.aretrieve.return_value = [scored("1", 1.0), scored("2", 0.95)]
        second_retriever = AsyncMock()
        second_retriever.aretrieve.return_value = [scored("2", 0.95), scored("3", 0.8)]

        # Build the hybrid retriever from the mocks and query it.
        hybrid_retriever = SimpleHybridRetriever(first_retriever, second_retriever)
        results = await hybrid_retriever._aretrieve(question)

        # Three unique nodes are expected.
        assert len(results) == 3
        assert {item.node.node_id for item in results} == {"1", "2", "3"}

        # The shared node keeps its score.
        score_by_id = {item.node.node_id: item.score for item in results}
        assert score_by_id["2"] == 0.95

View file

@ -1,19 +0,0 @@
import pytest
from metagpt.logs import logger
from metagpt.roles.ci.code_interpreter import CodeInterpreter
@pytest.mark.asyncio
@pytest.mark.parametrize("auto_run", [True, False])
async def test_code_interpreter(mocker, auto_run):
    """Run ``CodeInterpreter`` with notebook execution and user input stubbed; expect a non-empty reply."""
    # Stub out notebook execution and the interactive confirmation prompt.
    mocker.patch("metagpt.actions.ci.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
    mocker.patch("builtins.input", return_value="confirm")

    requirement = "Run data analysis on sklearn Iris dataset, include a plot"
    interpreter = CodeInterpreter(auto_run=auto_run, use_tools=True, tools=[])

    reply = await interpreter.run(requirement)
    logger.info(reply)
    assert len(reply.content) > 0

View file

@ -1,90 +0,0 @@
import pytest
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
from metagpt.logs import logger
from metagpt.roles.ci.ml_engineer import MLEngineer
from metagpt.schema import Message, Plan, Task
from metagpt.tools.tool_type import ToolType
from tests.metagpt.actions.ci.test_debug_code import CODE, DebugContext, ErrorStr
def test_mle_init():
    """Tool names passed to ``MLEngineer`` here are expected to leave ``tools`` empty."""
    engineer = MLEngineer(goal="test", auto_run=True, use_tools=True, tools=["tool1", "tool2"])
    assert engineer.tools == []
def _make_eda_task():
    """Return a new, unfinished EDA ``Task`` with id "1" (a distinct object on every call)."""
    return Task(
        task_id="1",
        dependent_task_ids=[],
        instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
        task_type="eda",
        code="",
        result="",
        is_success=False,
        is_finished=False,
    )


# Single-task plan assigned to ``mle.planner.plan`` by the MLEngineer tests in this module.
# ``tasks`` and ``task_map`` each receive their own Task instance.
MockPlan = Plan(
    goal="This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: 'tests/data/ml_datasets/titanic/split_train.csv', eval data path: 'tests/data/ml_datasets/titanic/split_eval.csv'.",
    context="",
    tasks=[_make_eda_task()],
    task_map={"1": _make_eda_task()},
    current_task_id="1",
)
@pytest.mark.asyncio
async def test_mle_write_code(mocker):
    """Code written for the mock plan should mention the dataset directory."""
    engineer = MLEngineer(auto_run=True, use_tools=True)
    engineer.planner.plan = MockPlan

    written, _ = await engineer._write_code()

    expected_path = "tests/data/ml_datasets/titanic"
    assert expected_path in written["code"]
@pytest.mark.asyncio
async def test_mle_update_data_columns(mocker):
    """``_update_data_columns`` returns something for a data-preprocess task."""
    engineer = MLEngineer(auto_run=True, use_tools=True)
    engineer.planner.plan = MockPlan

    # Switch the current task's type by hand so the update path is exercised.
    engineer.planner.plan.current_task.task_type = ToolType.DATA_PREPROCESS.value

    outcome = await engineer._update_data_columns()
    assert outcome is not None
@pytest.mark.asyncio
async def test_mle_debug_code(mocker):
    """With an execution error in working memory, ``_write_code`` still returns non-empty code."""
    engineer = MLEngineer(auto_run=True, use_tools=True)

    # Seed the state with an error message, the latest code and a debug context.
    error_message = Message(content=ErrorStr, cause_by=ExecuteNbCode)
    engineer.working_memory.add(error_message)
    engineer.latest_code = CODE
    engineer.debug_context = DebugContext

    written, _ = await engineer._write_code()
    assert len(written) > 0
@pytest.mark.skip
@pytest.mark.asyncio
async def test_ml_engineer():
    """End-to-end ``MLEngineer`` run on the titanic requirement (marked skip)."""
    data_path = "tests/data/ml_datasets/titanic"
    requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
    engineer = MLEngineer(auto_run=True, use_tools=True, tools=["FillMissingValue", "CatCross", "dummy_tool"])

    reply = await engineer.run(requirement)
    logger.info(reply)
    assert len(reply.content) > 0

View file

@ -0,0 +1,34 @@
import pytest
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
@pytest.mark.asyncio
@pytest.mark.parametrize("auto_run", [True, False])
async def test_interpreter(mocker, auto_run):
    """Run ``DataInterpreter`` with execution and input stubbed; expect a reply and recorded code."""
    # Stub out notebook execution and the interactive confirmation prompt.
    mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
    mocker.patch("builtins.input", return_value="confirm")

    requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."

    interpreter = DataInterpreter(auto_run=auto_run)
    reply = await interpreter.run(requirement)
    logger.info(reply)
    assert len(reply.content) > 0

    done_tasks = interpreter.planner.plan.get_finished_tasks()
    assert len(done_tasks) > 0
    # Spot-check one finished task to confirm its code was recorded.
    assert len(done_tasks[0].code) > 0
@pytest.mark.asyncio
async def test_interpreter_react_mode(mocker):
    """Run ``DataInterpreter`` in ``react`` mode with execution stubbed; expect a non-empty reply."""
    mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))

    requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."

    interpreter = DataInterpreter(react_mode="react")
    reply = await interpreter.run(requirement)
    logger.info(reply)
    assert len(reply.content) > 0

View file

@ -30,6 +30,17 @@ async def test_run(mocker, context):
language: str
agent_description: str
cause_by: str
agent_skills: list
agent_skills = [
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
]
inputs = [
{
@ -48,6 +59,7 @@ async def test_run(mocker, context):
"language": "English",
"agent_description": "chatterbox",
"cause_by": any_to_str(TalkAction),
"agent_skills": [],
},
{
"memory": {
@ -65,24 +77,16 @@ async def test_run(mocker, context):
"language": "English",
"agent_description": "painter",
"cause_by": any_to_str(SkillAction),
"agent_skills": agent_skills,
},
]
agent_skills = [
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
]
for i in inputs:
seed = Input(**i)
role = Assistant(language="Chinese", context=context)
role.context.kwargs.language = seed.language
role.context.kwargs.agent_description = seed.agent_description
role.context.kwargs.agent_skills = agent_skills
role.context.kwargs.agent_skills = seed.agent_skills
role.memory = seed.memory # Restore historical conversation content.
while True:

View file

@ -6,11 +6,11 @@
@File : test_tutorial_assistant.py
"""
import aiofiles
import pytest
from metagpt.const import TUTORIAL_PATH
from metagpt.roles.tutorial_assistant import TutorialAssistant
from metagpt.utils.common import aread
@pytest.mark.asyncio
@ -20,9 +20,8 @@ async def test_tutorial_assistant(language: str, topic: str, context):
msg = await role.run(topic)
assert TUTORIAL_PATH.exists()
filename = msg.content
async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
content = await reader.read()
assert "pip" in content
content = await aread(filename=filename)
assert "pip" in content
if __name__ == "__main__":

View file

@ -0,0 +1,37 @@
from metagpt.schema import Plan, Task
from metagpt.strategy.planner import Planner
from metagpt.strategy.task_type import TaskType
# Two tasks keyed by task id: a finished EDA task ("1") and a pending
# data-preprocess task ("2") that depends on it.
MOCK_TASK_MAP = {
    task.task_id: task
    for task in (
        Task(
            task_id="1",
            instruction="test instruction for finished task",
            task_type=TaskType.EDA.type_name,
            dependent_task_ids=[],
            code="some finished test code",
            result="some finished test result",
            is_finished=True,
        ),
        Task(
            task_id="2",
            instruction="test instruction for current task",
            task_type=TaskType.DATA_PREPROCESS.type_name,
            dependent_task_ids=["1"],
        ),
    )
}
# Plan over the mock tasks, with task "2" marked as the current one.
MOCK_PLAN = Plan(
    goal="test goal",
    tasks=[*MOCK_TASK_MAP.values()],
    task_map=MOCK_TASK_MAP,
    current_task_id="2",
)
def test_planner_get_plan_status():
    """The plan status mentions finished code and results, the current instruction and its guidance."""
    status = Planner(plan=MOCK_PLAN).get_plan_status()

    expected_fragments = (
        "some finished test code",
        "some finished test result",
        "test instruction for current task",
    )
    for fragment in expected_fragments:
        assert fragment in status

    # Guidance text of the current task's type.
    assert TaskType.DATA_PREPROCESS.value.guidance in status

View file

@ -1,9 +1,11 @@
from pathlib import Path
from pprint import pformat
import pytest
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
from metagpt.repo_parser import DotClassAttribute, DotClassMethod, DotReturn, RepoParser
def test_repo_parser():
@ -23,3 +25,140 @@ def test_error():
"""_parse_file should return empty list when file not existed"""
rsp = RepoParser._parse_file(Path("test_not_existed_file.py"))
assert rsp == []
@pytest.mark.parametrize(
    ("v", "name", "type_", "default_", "compositions"),
    [
        ("children : dict[str, 'ActionNode']", "children", "dict[str,ActionNode]", "", ["ActionNode"]),
        ("context : str", "context", "str", "", []),
        ("example", "example", "", "", []),
        ("expected_type : Type", "expected_type", "Type", "", ["Type"]),
        ("args : Optional[Dict]", "args", "Optional[Dict]", "", []),
        ("rsp : Optional[Message] = Message.Default", "rsp", "Optional[Message]", "Message.Default", ["Message"]),
        (
            "browser : Literal['chrome', 'firefox', 'edge', 'ie']",
            "browser",
            "Literal['chrome','firefox','edge','ie']",
            "",
            [],
        ),
        (
            "browser : Dict[ Message, Literal['chrome', 'firefox', 'edge', 'ie'] ]",
            "browser",
            "Dict[Message,Literal['chrome','firefox','edge','ie']]",
            "",
            ["Message"],
        ),
        ("attributes : List[ClassAttribute]", "attributes", "List[ClassAttribute]", "", ["ClassAttribute"]),
        ("attributes = []", "attributes", "", "[]", []),
        (
            "request_timeout: Optional[Union[float, Tuple[float, float]]]",
            "request_timeout",
            "Optional[Union[float,Tuple[float,float]]]",
            "",
            [],
        ),
    ],
)
def test_parse_member(v, name, type_, default_, compositions):
    """``DotClassAttribute.parse`` extracts each field and the result survives a JSON round trip."""
    parsed = DotClassAttribute.parse(v)

    # Field-by-field comparison against the expected values.
    assert parsed.name == name
    assert parsed.type_ == type_
    assert parsed.default_ == default_
    assert parsed.compositions == compositions
    assert parsed.description == v

    # Serialising to JSON and validating back must give an equal object.
    restored = DotClassAttribute.model_validate_json(parsed.model_dump_json())
    assert restored == parsed
@pytest.mark.parametrize(
("line", "package_name", "info"),
[
(
'"metagpt.roles.architect.Architect" [color="black", fontcolor="black", label=<{Architect|constraints : str<br ALIGN="LEFT"/>goal : str<br ALIGN="LEFT"/>name : str<br ALIGN="LEFT"/>profile : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
"metagpt.roles.architect.Architect",
"Architect|constraints : str\ngoal : str\nname : str\nprofile : str\n|",
),
(
'"metagpt.actions.skill_action.ArgumentsParingAction" [color="black", fontcolor="black", label=<{ArgumentsParingAction|args : Optional[Dict]<br ALIGN="LEFT"/>ask : str<br ALIGN="LEFT"/>prompt<br ALIGN="LEFT"/>rsp : Optional[Message]<br ALIGN="LEFT"/>skill<br ALIGN="LEFT"/>|parse_arguments(skill_name, txt): dict<br ALIGN="LEFT"/>run(with_message): Message<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
"metagpt.actions.skill_action.ArgumentsParingAction",
"ArgumentsParingAction|args : Optional[Dict]\nask : str\nprompt\nrsp : Optional[Message]\nskill\n|parse_arguments(skill_name, txt): dict\nrun(with_message): Message\n",
),
(
'"metagpt.strategy.base.BaseEvaluator" [color="black", fontcolor="black", label=<{BaseEvaluator|<br ALIGN="LEFT"/>|<I>status_verify</I>()<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
"metagpt.strategy.base.BaseEvaluator",
"BaseEvaluator|\n|<I>status_verify</I>()\n",
),
(
'"metagpt.configs.browser_config.BrowserConfig" [color="black", fontcolor="black", label=<{BrowserConfig|browser : Literal[\'chrome\', \'firefox\', \'edge\', \'ie\']<br ALIGN="LEFT"/>driver : Literal[\'chromium\', \'firefox\', \'webkit\']<br ALIGN="LEFT"/>engine<br ALIGN="LEFT"/>path : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
"metagpt.configs.browser_config.BrowserConfig",
"BrowserConfig|browser : Literal['chrome', 'firefox', 'edge', 'ie']\ndriver : Literal['chromium', 'firefox', 'webkit']\nengine\npath : str\n|",
),
(
'"metagpt.tools.search_engine_serpapi.SerpAPIWrapper" [color="black", fontcolor="black", label=<{SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]<br ALIGN="LEFT"/>model_config<br ALIGN="LEFT"/>params : dict<br ALIGN="LEFT"/>search_engine : Optional[Any]<br ALIGN="LEFT"/>serpapi_api_key : Optional[str]<br ALIGN="LEFT"/>|check_serpapi_api_key(val: str)<br ALIGN="LEFT"/>get_params(query: str): Dict[str, str]<br ALIGN="LEFT"/>results(query: str, max_results: int): dict<br ALIGN="LEFT"/>run(query, max_results: int, as_string: bool): str<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
"metagpt.tools.search_engine_serpapi.SerpAPIWrapper",
"SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]\nmodel_config\nparams : dict\nsearch_engine : Optional[Any]\nserpapi_api_key : Optional[str]\n|check_serpapi_api_key(val: str)\nget_params(query: str): Dict[str, str]\nresults(query: str, max_results: int): dict\nrun(query, max_results: int, as_string: bool): str\n",
),
],
)
def test_split_class_line(line, package_name, info):
    """Check that ``RepoParser._split_class_line`` splits ``line`` into the expected pair.

    Args:
        line: Raw class line handed to the parser.
        package_name: Expected first element of the result.
        info: Expected second element of the result.
    """
    actual_package, actual_info = RepoParser._split_class_line(line)
    assert actual_package == package_name
    assert actual_info == info
@pytest.mark.parametrize(
("v", "name", "args", "return_args"),
[
(
"<I>arequest</I>(method, url, params, headers, files, stream: Literal[True], request_id: Optional[str], request_timeout: Optional[Union[float, Tuple[float, float]]]): Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
"arequest",
[
DotClassAttribute(name="method", description="method"),
DotClassAttribute(name="url", description="url"),
DotClassAttribute(name="params", description="params"),
DotClassAttribute(name="headers", description="headers"),
DotClassAttribute(name="files", description="files"),
DotClassAttribute(name="stream", type_="Literal[True]", description="stream: Literal[True]"),
DotClassAttribute(name="request_id", type_="Optional[str]", description="request_id: Optional[str]"),
DotClassAttribute(
name="request_timeout",
type_="Optional[Union[float,Tuple[float,float]]]",
description="request_timeout: Optional[Union[float, Tuple[float, float]]]",
),
],
DotReturn(
type_="Tuple[AsyncGenerator[OpenAIResponse,None],bool,str]",
compositions=["AsyncGenerator", "OpenAIResponse"],
description="Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
),
),
(
"<I>update</I>(subject: str, predicate: str, object_: str)",
"update",
[
DotClassAttribute(name="subject", type_="str", description="subject: str"),
DotClassAttribute(name="predicate", type_="str", description="predicate: str"),
DotClassAttribute(name="object_", type_="str", description="object_: str"),
],
DotReturn(description=""),
),
],
)
def test_parse_method(v, name, args, return_args):
    """Parse a method description with ``DotClassMethod.parse`` and verify every field.

    Args:
        v: Method description string to parse.
        name: Expected method name.
        args: Expected list of parsed arguments.
        return_args: Expected parsed return description.
    """
    parsed = DotClassMethod.parse(v)
    assert parsed.name == name
    assert parsed.args == args
    assert parsed.return_args == return_args
    assert parsed.description == v

    # Dumping to JSON and validating it again must give back an equal object.
    serialized = parsed.model_dump_json()
    restored = DotClassMethod.model_validate_json(serialized)
    assert restored == parsed
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -18,9 +18,6 @@ from metagpt.actions.write_code import WriteCode
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
ClassAttribute,
ClassMethod,
ClassView,
CodeSummarizeContext,
Document,
Message,
@ -28,6 +25,9 @@ from metagpt.schema import (
Plan,
SystemMessage,
Task,
UMLClassAttribute,
UMLClassMethod,
UMLClassView,
UserMessage,
)
from metagpt.utils.common import any_to_str
@ -159,27 +159,26 @@ def test_CodeSummarizeContext(file_list, want):
def test_class_view():
attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True)
assert attr_a.get_mermaid(align=1) == "\t+int a=0*"
attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True)
assert attr_b.get_mermaid(align=0) == '#str b="0"$'
class_view = ClassView(name="A")
attr_a = UMLClassAttribute(name="a", value_type="int", default_value="0", visibility="+")
assert attr_a.get_mermaid(align=1) == "\t+int a=0"
attr_b = UMLClassAttribute(name="b", value_type="str", default_value="0", visibility="#")
assert attr_b.get_mermaid(align=0) == '#str b="0"'
class_view = UMLClassView(name="A")
class_view.attributes = [attr_a, attr_b]
method_a = ClassMethod(name="run", visibility="+", abstraction=True)
assert method_a.get_mermaid(align=1) == "\t+run()*"
method_b = ClassMethod(
method_a = UMLClassMethod(name="run", visibility="+")
assert method_a.get_mermaid(align=1) == "\t+run()"
method_b = UMLClassMethod(
name="_test",
visibility="#",
static=True,
args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")],
args=[UMLClassAttribute(name="a", value_type="str"), UMLClassAttribute(name="b", value_type="int")],
return_type="str",
)
assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$"
assert method_b.get_mermaid(align=0) == "#_test(str a,int b) str"
class_view.methods = [method_a, method_b]
assert (
class_view.get_mermaid(align=0)
== 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n'
== 'class A{\n\t+int a=0\n\t#str b="0"\n\t+run()\n\t#_test(str a,int b) str\n}\n'
)

View file

@ -0,0 +1,7 @@
from metagpt.tools.libs.email_login import email_login_imap
def test_email_login(mocker):
    """Call ``email_login_imap`` with ``MailBox.login`` replaced by a mock.

    There are no assertions; the test passes as long as the call does not raise.
    """
    patched_login = mocker.patch("metagpt.tools.libs.email_login.MailBox.login")
    patched_login.login.return_value = mocker.Mock()

    email_login_imap("test@outlook.com", "test_password")

View file

@ -60,18 +60,24 @@ async def test_generate_webpages(mock_webpage_filename_with_styles_and_scripts,
async def test_save_webpages_with_styles_and_scripts(mock_webpage_filename_with_styles_and_scripts, image_path):
generator = GPTvGenerator()
webpages = await generator.generate_webpages(image_path)
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
webpages_dir = generator.save_webpages(webpages=webpages, save_folder_name="test_1")
logs.logger.info(webpages_dir)
assert webpages_dir.exists()
assert (webpages_dir / "index.html").exists()
assert (webpages_dir / "styles.css").exists()
assert (webpages_dir / "scripts.js").exists()
@pytest.mark.asyncio
async def test_save_webpages_with_style_and_script(mock_webpage_filename_with_style_and_script, image_path):
generator = GPTvGenerator()
webpages = await generator.generate_webpages(image_path)
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
webpages_dir = generator.save_webpages(webpages=webpages, save_folder_name="test_2")
logs.logger.info(webpages_dir)
assert webpages_dir.exists()
assert (webpages_dir / "index.html").exists()
assert (webpages_dir / "style.css").exists()
assert (webpages_dir / "script.js").exists()
@pytest.mark.asyncio

View file

@ -4,8 +4,8 @@ from metagpt.tools.libs.web_scraping import scrape_web_playwright
@pytest.mark.asyncio
async def test_scrape_web_playwright():
test_url = "https://www.deepwisdom.ai"
async def test_scrape_web_playwright(http_server):
server, test_url = await http_server()
result = await scrape_web_playwright(test_url)
@ -21,3 +21,4 @@ async def test_scrape_web_playwright():
assert not result["inner_text"].endswith(" ")
assert not result["html"].startswith(" ")
assert not result["html"].endswith(" ")
await server.stop()

View file

@ -11,7 +11,6 @@ from typing import Callable
import pytest
from metagpt.config2 import config
from metagpt.configs.search_config import SearchConfig
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
@ -53,14 +52,11 @@ async def test_search_engine(
search_engine_config = {"engine": search_engine_type, "run_func": run_func}
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
assert config.search
search_engine_config["api_key"] = "mock-serpapi-key"
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
assert config.search
search_engine_config["api_key"] = "mock-google-key"
search_engine_config["cse_id"] = "mock-google-cse"
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
assert config.search
search_engine_config["api_key"] = "mock-serper-key"
async def test(search_engine):

View file

@ -1,44 +1,8 @@
from typing import Literal, Union
import pandas as pd
from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
def test_docstring_to_schema():
docstring = """
Some test desc.
Args:
features (list): Columns to be processed.
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
Defaults to None.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
expected = {
"description": "Some test desc.",
"parameters": {
"properties": {
"features": {"type": "list", "description": "Columns to be processed."},
"strategy": {
"type": "str",
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
"default": "'mean'",
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
},
"fill_value": {
"type": "int",
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
"default": "None",
},
},
"required": ["features"],
},
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
}
schema = docstring_to_schema(docstring)
assert schema == expected
from metagpt.tools.tool_convert import convert_code_to_tool_schema
class DummyClass:
@ -81,12 +45,25 @@ class DummyClass:
pass
def dummy_fn(df: pd.DataFrame) -> dict:
def dummy_fn(
df: pd.DataFrame,
s: str,
k: int = 5,
type: Literal["a", "b", "c"] = "a",
test_dict: dict[str, int] = None,
test_union: Union[str, list[str]] = "",
) -> dict:
"""
Analyzes a DataFrame and categorizes its columns based on data types.
Args:
df (pd.DataFrame): The DataFrame to be analyzed.
df: The DataFrame to be analyzed.
Another line for df.
s (str): Some test string param.
Another line for s.
k (int, optional): Some test integer param. Defaults to 5.
type (Literal["a", "b", "c"], optional): Some test type. Defaults to 'a'.
more_args: will be omitted here for testing
Returns:
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
@ -115,41 +92,21 @@ def test_convert_code_to_tool_schema_class():
"methods": {
"__init__": {
"type": "function",
"description": "Initialize self.",
"parameters": {
"properties": {
"features": {"type": "list", "description": "Columns to be processed."},
"strategy": {
"type": "str",
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
"default": "'mean'",
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
},
"fill_value": {
"type": "int",
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
"default": "None",
},
},
"required": ["features"],
},
"description": "Initialize self. ",
"signature": "(self, features: list, strategy: str = 'mean', fill_value=None)",
"parameters": "Args: features (list): Columns to be processed. strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
},
"fit": {
"type": "function",
"description": "Fit the FillMissingValue model.",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
"required": ["df"],
},
"description": "Fit the FillMissingValue model. ",
"signature": "(self, df: pandas.core.frame.DataFrame)",
"parameters": "Args: df (pd.DataFrame): The input DataFrame.",
},
"transform": {
"type": "function",
"description": "Transform the input DataFrame with the fitted model.",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
"required": ["df"],
},
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
"description": "Transform the input DataFrame with the fitted model. ",
"signature": "(self, df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame",
"parameters": "Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.",
},
},
}
@ -160,11 +117,9 @@ def test_convert_code_to_tool_schema_class():
def test_convert_code_to_tool_schema_function():
expected = {
"type": "function",
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
"required": ["df"],
},
"description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
"signature": "(df: pandas.core.frame.DataFrame, s: str, k: int = 5, type: Literal['a', 'b', 'c'] = 'a', test_dict: dict[str, int] = None, test_union: Union[str, list[str]] = '') -> dict",
"parameters": "Args: df: The DataFrame to be analyzed. Another line for df. s (str): Some test string param. Another line for s. k (int, optional): Some test integer param. Defaults to 5. type (Literal[\"a\", \"b\", \"c\"], optional): Some test type. Defaults to 'a'. more_args: will be omitted here for testing Returns: dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others'). Each key corresponds to a list of column names belonging to that category.",
}
schema = convert_code_to_tool_schema(dummy_fn)
assert schema == expected

View file

@ -0,0 +1,90 @@
import pytest
from metagpt.schema import Plan, Task
from metagpt.tools import TOOL_REGISTRY
from metagpt.tools.tool_recommend import (
BM25ToolRecommender,
ToolRecommender,
TypeMatchToolRecommender,
)
@pytest.fixture
def mock_plan(mocker):
    """Provide a ``Plan`` holding a single feature-engineering task, selected as current."""
    feature_task = Task(
        task_id="1",
        instruction="conduct feature engineering, add new features on the dataset",
        task_type="feature engineering",
    )
    return Plan(
        goal="test requirement",
        tasks=[feature_task],
        task_map={"1": feature_task},
        current_task_id="1",
    )
@pytest.fixture
def mock_bm25_tr(mocker):
    """Provide a ``BM25ToolRecommender`` built from two tool names and the "web scraping" entry."""
    return BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
def test_tr_init():
    """Check which tools a ``ToolRecommender`` ends up with for a mixed ``tools`` list."""
    requested = ["FillMissingValue", "PolynomialExpansion", "web scraping", "non-existing tool"]
    recommender = ToolRecommender(tools=requested)

    # "web scraping" is a tool tag, it has one tool scrape_web_playwright;
    # "non-existing tool" is expected to be absent from the result.
    expected_names = [
        "FillMissingValue",
        "PolynomialExpansion",
        "scrape_web_playwright",
    ]
    assert list(recommender.tools.keys()) == expected_names
def test_tr_init_default_tools_value():
    """A ``ToolRecommender`` created without arguments has an empty ``tools`` mapping."""
    recommender = ToolRecommender()
    assert recommender.tools == {}
def test_tr_init_tools_all():
    """Passing ``["<all>"]`` yields the same tool names, in the same order, as the registry."""
    recommender = ToolRecommender(tools=["<all>"])
    selected_names = list(recommender.tools.keys())
    registered_names = list(TOOL_REGISTRY.get_all_tools().keys())
    assert selected_names == registered_names
@pytest.mark.asyncio
async def test_bm25_tr_recall_with_plan(mock_plan, mock_bm25_tr):
    """Recalling with a plan returns three tools, ``PolynomialExpansion`` first."""
    recalled = await mock_bm25_tr.recall_tools(plan=mock_plan)
    assert len(recalled) == 3
    top_tool = recalled[0]
    assert top_tool.name == "PolynomialExpansion"
@pytest.mark.asyncio
async def test_bm25_tr_recall_no_plan(mock_plan, mock_bm25_tr):
    """Recalling from a plain context string (``plan=None``) returns three tools, ``PolynomialExpansion`` first."""
    query = "conduct feature engineering, add new features on the dataset"
    recalled = await mock_bm25_tr.recall_tools(context=query, plan=None)
    assert len(recalled) == 3
    top_tool = recalled[0]
    assert top_tool.name == "PolynomialExpansion"
@pytest.mark.asyncio
async def test_bm25_recommend_tools(mock_bm25_tr):
    """``recommend_tools`` returns two tools for the feature-engineering context."""
    query = "conduct feature engineering, add new features on the dataset"
    recommended = await mock_bm25_tr.recommend_tools(context=query)
    # web scraping tool should be filtered out at rank stage
    assert len(recommended) == 2
    assert recommended[0].name == "PolynomialExpansion"
@pytest.mark.asyncio
async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr):
    """``get_recommended_tool_info`` returns a string for the given plan."""
    tool_info = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan)
    assert isinstance(tool_info, str)
@pytest.mark.asyncio
async def test_tm_tr_recall_with_plan(mock_plan, mock_bm25_tr):
    """``TypeMatchToolRecommender`` recalls exactly one tool, ``PolynomialExpansion``, for the plan."""
    candidate_tools = ["FillMissingValue", "PolynomialExpansion", "web scraping"]
    recommender = TypeMatchToolRecommender(tools=candidate_tools)
    recalled = await recommender.recall_tools(plan=mock_plan)
    assert len(recalled) == 1
    assert recalled[0].name == "PolynomialExpansion"

View file

@ -1,7 +1,6 @@
import pytest
from metagpt.tools.tool_registry import ToolRegistry
from metagpt.tools.tool_type import ToolType
@pytest.fixture
@ -9,25 +8,11 @@ def tool_registry():
return ToolRegistry()
@pytest.fixture
def tool_registry_full():
return ToolRegistry(tool_types=ToolType)
# Test Initialization
def test_initialization(tool_registry):
assert isinstance(tool_registry, ToolRegistry)
assert tool_registry.tools == {}
assert tool_registry.tool_types == {}
assert tool_registry.tools_by_types == {}
# Test Initialization with tool types
def test_initialize_with_tool_types(tool_registry_full):
assert isinstance(tool_registry_full, ToolRegistry)
assert tool_registry_full.tools == {}
assert tool_registry_full.tools_by_types == {}
assert "data_preprocess" in tool_registry_full.tool_types
assert tool_registry.tools_by_tags == {}
class TestClassTool:
@ -72,31 +57,24 @@ def test_get_tool(tool_registry):
assert "description" in tool.schemas
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
def test_has_tool_type(tool_registry_full):
assert tool_registry_full.has_tool_type("data_preprocess")
assert not tool_registry_full.has_tool_type("NonexistentType")
def test_has_tool_tag(tool_registry):
tool_registry.register_tool(
"TestClassTool", "/path/to/tool", tool_source_object=TestClassTool, tags=["machine learning", "test"]
)
assert tool_registry.has_tool_tag("test")
assert not tool_registry.has_tool_tag("Non-existent tag")
def test_get_tool_type(tool_registry_full):
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
assert retrieved_type is not None
assert retrieved_type.name == "data_preprocess"
def test_get_tools_by_type(tool_registry):
tool_type_name = "TestType"
def test_get_tools_by_tag(tool_registry):
tool_tag_name = "Test Tag"
tool_name = "TestTool"
tool_path = "/path/to/tool"
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
tool_registry.register_tool(tool_name, tool_path, tags=[tool_tag_name], tool_source_object=TestClassTool)
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
assert tools_by_type is not None
assert tool_name in tools_by_type
tools_by_tag = tool_registry.get_tools_by_tag(tool_tag_name)
assert tools_by_tag is not None
assert tool_name in tools_by_tag
# Test case for when the tool type does not exist
def test_get_tools_by_nonexistent_type(tool_registry):
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
assert not tools_by_type
tools_by_tag_non_existent = tool_registry.get_tools_by_tag("Non-existent Tag")
assert not tools_by_tag_non_existent

View file

@ -9,14 +9,16 @@ from metagpt.utils.parse_html import WebPage
@pytest.mark.asyncio
@pytest.mark.parametrize(
"browser_type, url, urls",
"browser_type",
[
(WebBrowserEngineType.PLAYWRIGHT, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
(WebBrowserEngineType.SELENIUM, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
WebBrowserEngineType.PLAYWRIGHT,
WebBrowserEngineType.SELENIUM,
],
ids=["playwright", "selenium"],
)
async def test_scrape_web_page(browser_type, url, urls):
async def test_scrape_web_page(browser_type, http_server):
server, url = await http_server()
urls = [url, url, url]
browser = web_browser_engine.WebBrowserEngine(engine=browser_type)
result = await browser.run(url)
assert isinstance(result, WebPage)
@ -27,6 +29,7 @@ async def test_scrape_web_page(browser_type, url, urls):
assert isinstance(results, list)
assert len(results) == len(urls) + 1
assert all(("MetaGPT" in i.inner_text) for i in results)
await server.stop()
if __name__ == "__main__":

View file

@ -9,18 +9,28 @@ from metagpt.utils.parse_html import WebPage
@pytest.mark.asyncio
@pytest.mark.parametrize(
"browser_type, use_proxy, kwagrs, url, urls",
"browser_type, use_proxy, kwagrs,",
[
("chromium", {"proxy": True}, {}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
("firefox", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
("webkit", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
("chromium", {"proxy": True}, {}),
(
"firefox",
{},
{"ignore_https_errors": True},
),
(
"webkit",
{},
{"ignore_https_errors": True},
),
],
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
)
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd):
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, proxy, capfd, http_server):
server, url = await http_server()
urls = [url, url, url]
proxy_url = None
if use_proxy:
server, proxy_url = await proxy()
proxy_server, proxy_url = await proxy()
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, proxy=proxy_url, **kwagrs)
result = await browser.run(url)
assert isinstance(result, WebPage)
@ -32,8 +42,10 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
assert len(results) == len(urls) + 1
assert all(("MetaGPT" in i.inner_text) for i in results)
if use_proxy:
server.close()
proxy_server.close()
await proxy_server.wait_closed()
assert "Proxy:" in capfd.readouterr().out
await server.stop()
if __name__ == "__main__":

View file

@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import browsers
import pytest
@ -10,51 +11,48 @@ from metagpt.utils.parse_html import WebPage
@pytest.mark.asyncio
@pytest.mark.parametrize(
"browser_type, use_proxy, url, urls",
"browser_type, use_proxy,",
[
pytest.param(
"chrome",
True,
"https://deepwisdom.ai",
("https://deepwisdom.ai",),
False,
marks=pytest.mark.skipif(not browsers.get("chrome"), reason="chrome browser not found"),
),
pytest.param(
"firefox",
False,
"https://deepwisdom.ai",
("https://deepwisdom.ai",),
marks=pytest.mark.skipif(not browsers.get("firefox"), reason="firefox browser not found"),
),
pytest.param(
"edge",
False,
"https://deepwisdom.ai",
("https://deepwisdom.ai",),
marks=pytest.mark.skipif(not browsers.get("msedge"), reason="edge browser not found"),
),
],
ids=["chrome-normal", "firefox-normal", "edge-normal"],
)
async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd):
async def test_scrape_web_page(browser_type, use_proxy, proxy, capfd, http_server):
# Prerequisites
# firefox, chrome, Microsoft Edge
server, url = await http_server()
urls = [url, url, url]
proxy_url = None
if use_proxy:
server, proxy_url = await proxy()
proxy_server, proxy_url = await proxy()
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type, proxy=proxy_url)
result = await browser.run(url)
assert isinstance(result, WebPage)
assert "MetaGPT" in result.inner_text
if urls:
results = await browser.run(url, *urls)
assert isinstance(results, list)
assert len(results) == len(urls) + 1
assert all(("MetaGPT" in i.inner_text) for i in results)
results = await browser.run(url, *urls)
assert isinstance(results, list)
assert len(results) == len(urls) + 1
assert all(("MetaGPT" in i.inner_text) for i in results)
if use_proxy:
server.close()
assert "Proxy:" in capfd.readouterr().out
proxy_server.close()
await proxy_server.wait_closed()
assert "Proxy: localhost" in capfd.readouterr().out
await server.stop()
if __name__ == "__main__":

View file

@ -13,7 +13,6 @@ import uuid
from pathlib import Path
from typing import Any, Set
import aiofiles
import pytest
from pydantic import BaseModel
@ -125,9 +124,7 @@ class TestGetProjectRoot:
async def test_parse_data_exception(self, filename, want):
pathname = Path(__file__).parent.parent.parent / "data/output_parser" / filename
assert pathname.exists()
async with aiofiles.open(str(pathname), mode="r") as reader:
data = await reader.read()
data = await aread(filename=pathname)
result = OutputParser.parse_data(data=data)
assert want in result
@ -178,7 +175,7 @@ class TestGetProjectRoot:
],
)
def test_split_namespace(self, val, want):
res = split_namespace(val)
res = split_namespace(val, maxsplit=-1)
assert res == want
def test_read_json_file(self):
@ -198,12 +195,25 @@ class TestGetProjectRoot:
@pytest.mark.asyncio
async def test_read_write(self):
pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp"
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.tmp"
await awrite(pathname, "ABC")
data = await aread(pathname)
assert data == "ABC"
pathname.unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_read_write_error_charset(self):
    """Round-trip text through ``awrite``/``aread``, including a mismatched read encoding.

    First writes and reads a string with the default encoding. Then writes a
    second string as ``gb2312`` and reads it back while asking for ``utf-8``;
    the assertion requires ``aread`` to still return the original text.

    The scratch file is removed in a ``finally`` block so that it does not
    accumulate in the workspace when an assertion fails.
    """
    pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.txt"
    try:
        content = "中国abc123\u27f6"
        await awrite(filename=pathname, data=content)
        data = await aread(filename=pathname)
        assert data == content

        content = "GB18030 是中国国家标准局发布的新一代中文字符集标准,是 GBK 的升级版,支持更广泛的字符范围。"
        await awrite(filename=pathname, data=content, encoding="gb2312")
        # Deliberately request the wrong encoding for the bytes just written.
        data = await aread(filename=pathname, encoding="utf-8")
        assert data == content
    finally:
        # missing_ok: the file may not exist if the first write itself failed.
        pathname.unlink(missing_ok=True)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -10,15 +10,14 @@
import shutil
from pathlib import Path
import aiofiles
import pytest
from metagpt.utils.common import awrite
from metagpt.utils.git_repository import GitRepository
async def mock_file(filename, content=""):
async with aiofiles.open(str(filename), mode="w") as file:
await file.write(content)
await awrite(filename=filename, data=content)
async def mock_repo(local_path) -> (GitRepository, Path):

View file

@ -14,7 +14,7 @@ from metagpt.utils.mermaid import MMC1, mermaid_to_file
@pytest.mark.asyncio
@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer
async def test_mermaid(engine, context):
async def test_mermaid(engine, context, mermaid_mocker):
# nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli
# ink prerequisites: connected to internet
# playwright prerequisites: playwright install --with-deps chromium

View file

@ -211,6 +211,11 @@ value
output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
assert output == target_output
raw_output = '{"key": "url "http" \\"https\\" "}'
target_output = '{"key": "url \\"http\\" \\"https\\" "}'
output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 1 column 15 (char 14)")
assert output == target_output
def test_retry_parse_json_text():
from metagpt.utils.repair_llm_raw_output import retry_parse_json_text

View file

@ -0,0 +1,25 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import uuid
from pathlib import Path
import pytest
from metagpt.utils.repo_to_markdown import repo_to_markdown
@pytest.mark.parametrize(
    ["repo_path", "output"],
    [
        (
            Path(__file__).parent.parent,
            Path(__file__).parent.parent.parent / f"workspace/unittest/{uuid.uuid4().hex}.md",
        )
    ],
)
@pytest.mark.asyncio
async def test_repo_to_markdown(repo_path: Path, output: Path):
    """Convert ``repo_path`` with ``repo_to_markdown`` and check both the file and the return value.

    Args:
        repo_path: Directory passed to the converter.
        output: Markdown file the converter is asked to write; removed once the checks pass.
    """
    rendered = await repo_to_markdown(repo_path=repo_path, output=output)

    assert output.exists()
    assert rendered

    output.unlink(missing_ok=True)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -9,10 +9,10 @@ import uuid
from pathlib import Path
import aioboto3
import aiofiles
import pytest
from metagpt.config2 import Config
from metagpt.configs.s3_config import S3Config
from metagpt.utils.common import aread
from metagpt.utils.s3 import S3
@ -30,6 +30,14 @@ async def test_s3(mocker):
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mocker.patch.object(aioboto3.Session, "client", return_value=mock_client)
mock_config = mocker.Mock()
mock_config.s3 = S3Config(
access_key="mock_access_key",
secret_key="mock_secret_key",
endpoint="http://mock.endpoint",
bucket="mock_bucket",
)
mocker.patch.object(Config, "default", return_value=mock_config)
# Prerequisites
s3 = Config.default().s3
@ -37,7 +45,7 @@ async def test_s3(mocker):
conn = S3(s3)
object_name = "unittest.bak"
await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name)
pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak")
pathname = (Path(__file__).parent / "../../../workspace/unittest" / uuid.uuid4().hex).with_suffix(".bak")
pathname.unlink(missing_ok=True)
await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname))
assert pathname.exists()
@ -45,8 +53,7 @@ async def test_s3(mocker):
assert url
bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name)
assert bin_data
async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader:
data = await reader.read()
data = await aread(filename=__file__)
res = await conn.cache(data, ".bak", "script")
assert "http" in res
@ -60,8 +67,6 @@ async def test_s3(mocker):
except Exception:
pass
await reader.close()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -6,7 +6,7 @@
import nbformat
import pytest
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
from metagpt.utils.common import read_json_file
from metagpt.utils.save_code import DATA_PATH, save_code_file

View file

@ -22,7 +22,7 @@ def _paragraphs(n):
@pytest.mark.parametrize(
"msgs, model_name, system_text, reserved, expected",
[
(_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
(_msgs(), "gpt-3.5-turbo-0613", "System", 1500, 1),
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
(_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
(_msgs(), "gpt-4", "System", 2000, 3),
@ -32,21 +32,23 @@ def _paragraphs(n):
],
)
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
length = len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000
assert length == expected
@pytest.mark.parametrize(
"text, prompt_template, model_name, system_text, reserved, expected",
[
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1500, 2),
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
(" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1000, 8),
],
)
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
assert len(ret) == expected
chunk = len(list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved)))
assert chunk == expected
@pytest.mark.parametrize(

View file

@ -0,0 +1,64 @@
from pathlib import Path
from typing import List
import pytest
from metagpt.utils.tree import _print_tree, tree
@pytest.mark.parametrize(
("root", "rules"),
[
(str(Path(__file__).parent / "../.."), None),
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
],
)
def test_tree(root: str, rules: str):
    """``tree`` returns a truthy result for ``root``, with or without a gitignore file."""
    listing = tree(root=root, gitignore=rules)
    assert listing
@pytest.mark.parametrize(
("root", "rules"),
[
(str(Path(__file__).parent / "../.."), None),
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
],
)
def test_tree_command(root: str, rules: str):
    """``tree`` called with ``run_command=True`` returns a truthy result for ``root``."""
    listing = tree(root=root, gitignore=rules, run_command=True)
    assert listing
@pytest.mark.parametrize(
("tree", "want"),
[
({"a": {"b": {}, "c": {}}}, ["a", "+-- b", "+-- c"]),
({"a": {"b": {}, "c": {"d": {}}}}, ["a", "+-- b", "+-- c", " +-- d"]),
(
{"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}},
["a", "+-- b", "| +-- e", "| +-- f", "| +-- g", "+-- c", " +-- d"],
),
(
{"h": {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}, "i": {}}},
[
"h",
"+-- a",
"| +-- b",
"| | +-- e",
"| | +-- f",
"| | +-- g",
"| +-- c",
"| +-- d",
"+-- i",
],
),
],
)
def test__print_tree(tree: dict, want: List[str]):
    """``_print_tree`` renders the nested dict ``tree`` as exactly the lines in ``want``."""
    rendered_lines = _print_tree(tree)
    assert rendered_lines == want
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,45 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : test_visual_graph_repo.py
@Desc : Unit tests for testing and demonstrating the usage of VisualDiGraphRepo.
"""
import re
from pathlib import Path
import pytest
from metagpt.utils.common import remove_affix, split_namespace
from metagpt.utils.visual_graph_repo import VisualDiGraphRepo
@pytest.mark.asyncio
async def test_visual_di_graph_repo(context, mocker):
    """Load a stored graph and save its class view and sequence views as mermaid markdown.

    Each view must be non-empty. Every view is written through the context's
    graph repo, wrapped in a mermaid code fence.
    """

    async def save_mermaid(name, body):
        # Wrap the diagram text in a mermaid fence and store it under ``name``.
        await context.repo.resources.graph_repo.save(filename=name, content=f"```mermaid\n{body}\n```\n")

    def flatten(namespace):
        # Collapse runs of ':', '/', '\' and '.' into a single underscore.
        return re.sub(r"[:/\\\.]+", "_", namespace)

    source_file = Path(__file__).parent / "../../data/graph_db/networkx.sequence_view.json"
    repo = await VisualDiGraphRepo.load_from(filename=source_file)

    class_view = await repo.get_mermaid_class_view()
    assert class_view
    await save_mermaid("class_view.md", class_view)

    sequence_views = await repo.get_mermaid_sequence_views()
    assert sequence_views
    for namespace, diagram in sequence_views:
        target = flatten(namespace) + ".sequence_view.md"
        await save_mermaid(target, diagram.strip(" `"))

    sequence_view_vers = await repo.get_mermaid_sequence_view_versions()
    assert sequence_view_vers
    for namespace, versioned in sequence_view_vers:
        # Each entry is split into its version and the diagram text.
        ver, diagram = split_namespace(versioned)
        target = flatten(namespace) + f".{ver}.sequence_view_ver.md"
        diagram = remove_affix(diagram).strip(" `")
        await save_mermaid(target, diagram)
if __name__ == "__main__":
    # Executed as a script: hand this file to pytest, passing the "-s" flag.
    pytest.main([__file__, "-s"])

View file

@ -10,6 +10,7 @@ class MockAioResponse:
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
rsp_cache: dict[str, str] = {}
name = "aiohttp"
status = 200
def __init__(self, session, method, url, **kwargs) -> None:
fn = self.check_funcs.get((method, url))
@ -22,6 +23,7 @@ class MockAioResponse:
async def __aenter__(self):
    """Enter the async context and return this object.

    When ``self.response`` is set, it is entered and its ``status`` is copied
    onto this object. Otherwise, when ``self.mng`` is set, it is entered and
    the result is stored as ``self.response``.
    """
    if not self.response:
        if self.mng:
            self.response = await self.mng.__aenter__()
        return self
    await self.response.__aenter__()
    self.status = self.response.status
    return self
@ -41,6 +43,17 @@ class MockAioResponse:
self.rsp_cache[self.key] = data
return data
@property
def content(self):
    """Return this object itself, so attribute access continues on the same instance."""
    return self
async def read(self):
    """Return the response body, replaying it from ``rsp_cache`` when possible.

    On a cache hit, the stored text (the ``str()`` form written on a miss,
    e.g. ``"b'...'"``) is parsed back into a Python literal and returned.
    On a miss, the body is awaited from ``self.response.content.read()``,
    its ``str()`` form is stored under ``self.key``, and the raw data is
    returned.

    Raises:
        ValueError: if a cached entry is not a plain Python literal.
        SyntaxError: if a cached entry is not valid Python syntax.
    """
    import ast  # local import so the block does not depend on a file-level import

    if self.key in self.rsp_cache:
        # Cache entries may come from fixture files on disk. literal_eval only
        # accepts literals, so a tampered entry cannot execute code the way
        # eval() would.
        return ast.literal_eval(self.rsp_cache[self.key])
    data = await self.response.content.read()
    self.rsp_cache[self.key] = str(data)
    return data
def raise_for_status(self):
    """Call ``raise_for_status`` on ``self.response`` when one is set; otherwise do nothing."""
    wrapped = self.response
    if not wrapped:
        return
    wrapped.raise_for_status()

View file

@ -3,8 +3,9 @@ from typing import Optional, Union
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.logs import logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import Message
@ -24,31 +25,21 @@ class MockLLM(OriginalLLM):
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""Overwrite original acompletion_text to cancel retry"""
if stream:
resp = self._achat_completion_stream(messages, timeout=timeout)
collected_messages = []
async for i in resp:
log_llm_stream(i)
collected_messages.append(i)
full_reply_content = "".join(collected_messages)
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
resp = await self._achat_completion_stream(messages, timeout=timeout)
return resp
rsp = await self._achat_completion(messages, timeout=timeout)
return self.get_choice_text(rsp)
async def original_aask(
self,
msg: str,
msg: Union[str, list[dict[str, str]]],
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
stream=True,
):
"""A copy of metagpt.provider.base_llm.BaseLLM.aask, we can't use super().aask because it will be mocked"""
) -> str:
if system_msgs:
message = self._system_msgs(system_msgs)
else:
@ -57,7 +48,11 @@ class MockLLM(OriginalLLM):
message = []
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg, images=images))
if isinstance(msg, str):
message.append(self._user_msg(msg, images=images))
else:
message.extend(msg)
logger.debug(message)
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp
@ -76,19 +71,27 @@ class MockLLM(OriginalLLM):
A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it will be mocked.
Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
"""
if "tools" not in kwargs:
configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
kwargs.update(configs)
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
async def aask(
self,
msg: str,
msg: Union[str, list[dict[str, str]]],
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
stream=True,
) -> str:
msg_key = msg # used to identify it a message has been called before
# used to identify if a message has been called before
if isinstance(msg, list):
msg_key = "#MSG_SEP#".join([m["content"] for m in msg])
else:
msg_key = msg
if system_msgs:
joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#"
msg_key = joined_system_msg + msg_key
@ -101,8 +104,7 @@ class MockLLM(OriginalLLM):
return rsp
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
messages = self._process_message(messages)
msg_key = json.dumps(messages, ensure_ascii=False)
msg_key = json.dumps(self.format_msg(messages), ensure_ascii=False)
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
return rsp

View file

@ -1,7 +0,0 @@
llm:
api_type: "spark"
app_id: "xxx"
api_key: "xxx"
api_secret: "xxx"
domain: "generalv2"
base_url: "wss://spark-api.xf-yun.com/v3.1/chat"