mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-20 15:38:09 +02:00
Merge branch 'geekan:main' into feat_st_game
This commit is contained in:
commit
64350d2c6d
303 changed files with 9155 additions and 3587 deletions
|
|
@ -12,8 +12,10 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import aiohttp.web
|
||||
import pytest
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
|
||||
|
|
@ -111,12 +113,13 @@ def proxy():
|
|||
while not reader.at_eof():
|
||||
writer.write(await reader.read(2048))
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async def handle_client(reader, writer):
|
||||
data = await reader.readuntil(b"\r\n\r\n")
|
||||
print(f"Proxy: {data}") # checking with capfd fixture
|
||||
infos = pattern.match(data)
|
||||
host, port = infos.group("host"), infos.group("port")
|
||||
print(f"Proxy: {host}") # checking with capfd fixture
|
||||
port = int(port) if port else 80
|
||||
remote_reader, remote_writer = await asyncio.open_connection(host, port)
|
||||
if data.startswith(b"CONNECT"):
|
||||
|
|
@ -171,9 +174,8 @@ def new_filename(mocker):
|
|||
yield mocker
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_rsp_cache():
|
||||
rsp_cache_file_path = TEST_DATA_PATH / "search_rsp_cache.json" # read repo-provided
|
||||
def _rsp_cache(name):
|
||||
rsp_cache_file_path = TEST_DATA_PATH / f"{name}.json" # read repo-provided
|
||||
if os.path.exists(rsp_cache_file_path):
|
||||
with open(rsp_cache_file_path, "r") as f1:
|
||||
rsp_cache_json = json.load(f1)
|
||||
|
|
@ -184,6 +186,16 @@ def search_rsp_cache():
|
|||
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_rsp_cache():
|
||||
yield from _rsp_cache("search_rsp_cache")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mermaid_rsp_cache():
|
||||
yield from _rsp_cache("mermaid_rsp_cache")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_mocker(mocker):
|
||||
MockResponse = type("MockResponse", (MockAioResponse,), {})
|
||||
|
|
@ -231,3 +243,40 @@ def search_engine_mocker(aiohttp_mocker, curl_cffi_mocker, httplib2_mocker, sear
|
|||
aiohttp_mocker.rsp_cache = httplib2_mocker.rsp_cache = curl_cffi_mocker.rsp_cache = search_rsp_cache
|
||||
aiohttp_mocker.check_funcs = httplib2_mocker.check_funcs = curl_cffi_mocker.check_funcs = check_funcs
|
||||
yield check_funcs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def http_server():
|
||||
async def handler(request):
|
||||
return aiohttp.web.Response(
|
||||
text="""<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8">
|
||||
<title>MetaGPT</title></head><body><h1>MetaGPT</h1></body></html>""",
|
||||
content_type="text/html",
|
||||
)
|
||||
|
||||
async def start():
|
||||
server = aiohttp.web.Server(handler)
|
||||
runner = aiohttp.web.ServerRunner(server)
|
||||
await runner.setup()
|
||||
site = aiohttp.web.TCPSite(runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
_, port, *_ = site._server.sockets[0].getsockname()
|
||||
return site, f"http://127.0.0.1:{port}"
|
||||
|
||||
return start
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache):
|
||||
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
|
||||
aiohttp_mocker.rsp_cache = mermaid_rsp_cache
|
||||
aiohttp_mocker.check_funcs = check_funcs
|
||||
yield check_funcs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git_dir():
|
||||
"""Fixture to get the unittest directory."""
|
||||
git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"
|
||||
git_dir.mkdir(parents=True, exist_ok=True)
|
||||
return git_dir
|
||||
|
|
|
|||
1
tests/data/graph_db/networkx.class_view.json
Normal file
1
tests/data/graph_db/networkx.class_view.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
tests/data/graph_db/networkx.sequence_view.json
Normal file
1
tests/data/graph_db/networkx.sequence_view.json
Normal file
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
|
|
@ -149,7 +149,7 @@ sequenceDiagram
|
|||
|
||||
The requirement analysis suggests the need for a clean and intuitive interface. Since we are using a command-line interface, we need to ensure that the text-based UI is as user-friendly as possible. Further clarification on whether a graphical user interface (GUI) is expected in the future would be helpful for planning the extendability of the game."""
|
||||
|
||||
TASKS_SAMPLE = """
|
||||
TASK_SAMPLE = """
|
||||
## Required Python packages
|
||||
|
||||
- random==2.2.1
|
||||
|
|
@ -345,7 +345,7 @@ REFINED_DESIGN_JSON = {
|
|||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
REFINED_TASKS_JSON = {
|
||||
REFINED_TASK_JSON = {
|
||||
"Required Python packages": ["random==2.2.1", "Tkinter==8.6"],
|
||||
"Required Other language third-party packages": ["No third-party dependencies required"],
|
||||
"Refined Logic Analysis": [
|
||||
|
|
@ -373,7 +373,14 @@ REFINED_TASKS_JSON = {
|
|||
}
|
||||
|
||||
CODE_PLAN_AND_CHANGE_SAMPLE = {
|
||||
"Code Plan And Change": '\n1. Plan for gui.py: Develop the GUI using Tkinter to replace the command-line interface. Start by setting up the main window and event handling. Then, add widgets for displaying the game status, results, and feedback. Implement interactive elements for difficulty selection and visualize the guess history. Finally, create animations for guess feedback and ensure responsiveness across different screen sizes.\n```python\nclass GUI:\n- pass\n+ def __init__(self):\n+ self.setup_window()\n+\n+ def setup_window(self):\n+ # Initialize the main window using Tkinter\n+ pass\n+\n+ def bind_events(self):\n+ # Bind button clicks and other events\n+ pass\n+\n+ def update_feedback(self, message: str):\n+ # Update the feedback label with the given message\n+ pass\n+\n+ def update_attempts(self, attempts: int):\n+ # Update the attempts label with the number of attempts\n+ pass\n+\n+ def update_history(self, history: list):\n+ # Update the history view with the list of past guesses\n+ pass\n+\n+ def show_difficulty_selector(self):\n+ # Show buttons or a dropdown for difficulty selection\n+ pass\n+\n+ def animate_guess_result(self, correct: bool):\n+ # Trigger an animation for correct or incorrect guesses\n+ pass\n```\n\n2. Plan for main.py: Modify the main.py to initialize the GUI and start the event-driven game loop. Ensure that the GUI is the primary interface for user interaction.\n```python\nclass Main:\n def main(self):\n- user_interface = UI()\n- user_interface.start()\n+ graphical_user_interface = GUI()\n+ graphical_user_interface.setup_window()\n+ graphical_user_interface.bind_events()\n+ # Start the Tkinter main loop\n+ pass\n\n if __name__ == "__main__":\n main_instance = Main()\n main_instance.main()\n```\n\n3. Plan for ui.py: Refactor ui.py to work with the new GUI class. Remove command-line interactions and delegate display and input tasks to the GUI.\n```python\nclass UI:\n- def display_message(self, message: str):\n- print(message)\n+\n+ def display_message(self, message: str):\n+ # This method will now pass the message to the GUI to display\n+ pass\n\n- def get_user_input(self, prompt: str) -> str:\n- return input(prompt)\n+\n+ def get_user_input(self, prompt: str) -> str:\n+ # This method will now trigger the GUI to get user input\n+ pass\n\n- def show_attempts(self, attempts: int):\n- print(f"Number of attempts: {attempts}")\n+\n+ def show_attempts(self, attempts: int):\n+ # This method will now update the GUI with the number of attempts\n+ pass\n\n- def show_history(self, history: list):\n- print("Guess history:")\n- for guess in history:\n- print(guess)\n+\n+ def show_history(self, history: list):\n+ # This method will now update the GUI with the guess history\n+ pass\n```\n\n4. Plan for game.py: Ensure game.py remains mostly unchanged as it contains the core game logic. However, make minor adjustments if necessary to integrate with the new GUI.\n```python\nclass Game:\n # No changes required for now\n```\n'
|
||||
"Development Plan": [
|
||||
"Develop the GUI using Tkinter to replace the command-line interface. Start by setting up the main window and event handling. Then, add widgets for displaying the game status, results, and feedback. Implement interactive elements for difficulty selection and visualize the guess history. Finally, create animations for guess feedback and ensure responsiveness across different screen sizes.",
|
||||
"Modify the main.py to initialize the GUI and start the event-driven game loop. Ensure that the GUI is the primary interface for user interaction.",
|
||||
],
|
||||
"Incremental Change": [
|
||||
"""```diff\nclass GUI:\n- pass\n+ def __init__(self):\n+ self.setup_window()\n+\n+ def setup_window(self):\n+ # Initialize the main window using Tkinter\n+ pass\n+\n+ def bind_events(self):\n+ # Bind button clicks and other events\n+ pass\n+\n+ def update_feedback(self, message: str):\n+ # Update the feedback label with the given message\n+ pass\n+\n+ def update_attempts(self, attempts: int):\n+ # Update the attempts label with the number of attempts\n+ pass\n+\n+ def update_history(self, history: list):\n+ # Update the history view with the list of past guesses\n+ pass\n+\n+ def show_difficulty_selector(self):\n+ # Show buttons or a dropdown for difficulty selection\n+ pass\n+\n+ def animate_guess_result(self, correct: bool):\n+ # Trigger an animation for correct or incorrect guesses\n+ pass\n```""",
|
||||
"""```diff\nclass Main:\n def main(self):\n- user_interface = UI()\n- user_interface.start()\n+ graphical_user_interface = GUI()\n+ graphical_user_interface.setup_window()\n+ graphical_user_interface.bind_events()\n+ # Start the Tkinter main loop\n+ pass\n\n if __name__ == "__main__":\n main_instance = Main()\n main_instance.main()\n```\n\n3. Plan for ui.py: Refactor ui.py to work with the new GUI class. Remove command-line interactions and delegate display and input tasks to the GUI.\n```python\nclass UI:\n- def display_message(self, message: str):\n- print(message)\n+\n+ def display_message(self, message: str):\n+ # This method will now pass the message to the GUI to display\n+ pass\n\n- def get_user_input(self, prompt: str) -> str:\n- return input(prompt)\n+\n+ def get_user_input(self, prompt: str) -> str:\n+ # This method will now trigger the GUI to get user input\n+ pass\n\n- def show_attempts(self, attempts: int):\n- print(f"Number of attempts: {attempts}")\n+\n+ def show_attempts(self, attempts: int):\n+ # This method will now update the GUI with the number of attempts\n+ pass\n\n- def show_history(self, history: list):\n- print("Guess history:")\n- for guess in history:\n- print(guess)\n+\n+ def show_history(self, history: list):\n+ # This method will now update the GUI with the guess history\n+ pass\n```\n\n4. Plan for game.py: Ensure game.py remains mostly unchanged as it contains the core game logic. However, make minor adjustments if necessary to integrate with the new GUI.\n```python\nclass Game:\n # No changes required for now\n```\n""",
|
||||
],
|
||||
}
|
||||
|
||||
REFINED_CODE_INPUT_SAMPLE = """
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
4
tests/data/mermaid_rsp_cache.json
Normal file
4
tests/data/mermaid_rsp_cache.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -1,51 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/11/2024 8:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.ci.debug_code import DebugCode
|
||||
from metagpt.schema import Message
|
||||
|
||||
ErrorStr = """Tested passed:
|
||||
|
||||
Tests failed:
|
||||
assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5]
|
||||
"""
|
||||
|
||||
CODE = """
|
||||
def sort_array(arr):
|
||||
# Helper function to count the number of ones in the binary representation
|
||||
def count_ones(n):
|
||||
return bin(n).count('1')
|
||||
|
||||
# Sort the array using a custom key function
|
||||
# The key function returns a tuple (number of ones, value) for each element
|
||||
# This ensures that if two elements have the same number of ones, they are sorted by their value
|
||||
sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x))
|
||||
|
||||
return sorted_arr
|
||||
```
|
||||
"""
|
||||
|
||||
DebugContext = '''Solve the problem in Python:
|
||||
def sort_array(arr):
|
||||
"""
|
||||
In this Kata, you have to sort an array of non-negative integers according to
|
||||
number of ones in their binary representation in ascending order.
|
||||
For similar number of ones, sort based on decimal value.
|
||||
|
||||
It must be implemented like this:
|
||||
>>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
|
||||
>>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
|
||||
>>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
|
||||
"""
|
||||
'''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_code():
|
||||
debug_context = Message(content=DebugContext)
|
||||
new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
|
||||
assert "def sort_array(arr)" in new_code["code"]
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.ml_action import WriteCodeWithToolsML
|
||||
from metagpt.schema import Plan, Task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code_ml = WriteCodeWithToolsML()
|
||||
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="随机生成一个pandas DataFrame数据集",
|
||||
task_type="other",
|
||||
dependent_task_ids=[],
|
||||
code="""
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({
|
||||
'a': [1, 2, 3, 4, 5],
|
||||
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
|
||||
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
|
||||
'd': [1, 2, 3, 4, 5]
|
||||
})
|
||||
""",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="对数据集进行数据清洗",
|
||||
task_type="data_preprocess",
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
plan = Plan(
|
||||
goal="构造数据集并进行数据清洗",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
column_info = ""
|
||||
|
||||
_, code_with_ml = await write_code_ml.run([], plan, column_info)
|
||||
code_with_ml = code_with_ml["code"]
|
||||
assert len(code_with_ml) > 0
|
||||
print(code_with_ml)
|
||||
|
|
@ -1,324 +0,0 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.ci.write_analysis_code import (
|
||||
WriteCodeWithoutTools,
|
||||
WriteCodeWithTools,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.strategy.planner import STRUCTURAL_CONTEXT
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_by_list_plan():
|
||||
write_code = WriteCodeWithoutTools()
|
||||
execute_code = ExecuteNbCode()
|
||||
messages = []
|
||||
plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "回顾已完成的任务", "求均值", "总结"]
|
||||
for task in plan:
|
||||
print(f"\n任务: {task}\n\n")
|
||||
messages.append(Message(task, role="assistant"))
|
||||
code = await write_code.run(messages)
|
||||
if task.startswith(("回顾", "总结")):
|
||||
assert code["language"] == "markdown"
|
||||
else:
|
||||
assert code["language"] == "python"
|
||||
messages.append(Message(code["code"], role="assistant"))
|
||||
assert len(code) > 0
|
||||
output, _ = await execute_code.run(**code)
|
||||
print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n")
|
||||
messages.append(output)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_recommendation():
|
||||
task = "clean and preprocess the data"
|
||||
available_tools = {
|
||||
"FillMissingValue": "Filling missing values",
|
||||
"SplitBins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
|
||||
}
|
||||
write_code = WriteCodeWithTools()
|
||||
tools = await write_code._recommend_tool(task, available_tools)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert "FillMissingValue" in tools
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code = WriteCodeWithTools()
|
||||
|
||||
requirement = "构造数据集并进行数据清洗"
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="随机生成一个pandas DataFrame数据集",
|
||||
task_type="other",
|
||||
dependent_task_ids=[],
|
||||
code="""
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({
|
||||
'a': [1, 2, 3, 4, 5],
|
||||
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
|
||||
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
|
||||
'd': [1, 2, 3, 4, 5]
|
||||
})
|
||||
""",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="对数据集进行数据清洗",
|
||||
task_type="data_preprocess",
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
plan = Plan(
|
||||
goal="构造数据集并进行数据清洗",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
|
||||
context = STRUCTURAL_CONTEXT.format(
|
||||
user_requirement=requirement,
|
||||
context=plan.context,
|
||||
tasks=list(task_map.values()),
|
||||
current_task=plan.current_task.model_dump_json(),
|
||||
)
|
||||
context_msg = [Message(content=context, role="user")]
|
||||
|
||||
code = await write_code.run(context_msg, plan)
|
||||
code = code["code"]
|
||||
assert len(code) > 0
|
||||
print(code)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_to_correct_error():
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
read a dataset test.csv and print its head
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "import pandas and load the dataset from 'test.csv'.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Print the head of the dataset to display the first few rows.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "1", "dependent_task_ids": [], "instruction": "import pandas and load the dataset from 'test.csv'.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv
|
||||
error = """
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 2, in <module>
|
||||
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 478, in read_excel
|
||||
io = ExcelFile(io, storage_options=storage_options, engine=engine)
|
||||
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 1500, in __init__
|
||||
raise ValueError(
|
||||
ValueError: Excel file format cannot be determined, you must specify an engine manually.
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
Message(content=wrong_code, role="assistant"),
|
||||
Message(content=error, role="user"),
|
||||
]
|
||||
new_code = await WriteCodeWithoutTools().run(context=context)
|
||||
new_code = new_code["code"]
|
||||
print(new_code)
|
||||
assert "read_csv" in new_code # should correct read_excel to read_csv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_simple():
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
read a dataset test.csv and print its head
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "import pandas and load the dataset from 'test.csv'.",
|
||||
"task_type": "",
|
||||
"code": "import pandas as pd\ndata = pd.read_csv('test.csv')",
|
||||
"result": "",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Print the head of the dataset to display the first few rows.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Print the head of the dataset to display the first few rows.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
code = await WriteCodeWithoutTools().run(context=context)
|
||||
code = code["code"]
|
||||
print(code)
|
||||
assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long():
|
||||
"""test code reuse for long context"""
|
||||
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
Run data analysis on sklearn Iris dataset, include a plot
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "Load the Iris dataset from sklearn.",
|
||||
"task_type": "",
|
||||
"code": "from sklearn.datasets import load_iris\niris_data = load_iris()\niris_data['data'][0:5], iris_data['target'][0:5]",
|
||||
"result": "(array([[5.1, 3.5, 1.4, 0.2],\n [4.9, 3. , 1.4, 0.2],\n [4.7, 3.2, 1.3, 0.2],\n [4.6, 3.1, 1.5, 0.2],\n [5. , 3.6, 1.4, 0.2]]),\n array([0, 0, 0, 0, 0]))",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Perform exploratory data analysis on the Iris dataset.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "3",
|
||||
"dependent_task_ids": [
|
||||
"2"
|
||||
],
|
||||
"instruction": "Create a plot visualizing the Iris dataset features.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Perform exploratory data analysis on the Iris dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
trials_num = 5
|
||||
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
|
||||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_iris" not in result["code"] and "iris_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
assert success_rate >= 0.8
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long_for_wine():
|
||||
"""test code reuse for long context"""
|
||||
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "Load the sklearn Wine recognition dataset and perform exploratory data analysis."
|
||||
"task_type": "",
|
||||
"code": "from sklearn.datasets import load_wine\n# Load the Wine recognition dataset\nwine_data = load_wine()\n# Perform exploratory data analysis\nwine_data.keys()",
|
||||
"result": "Truncated to show only the last 1000 characters\ndict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names'])",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": ["1"],
|
||||
"instruction": "Create a plot to visualize some aspect of the wine dataset."
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "3",
|
||||
"dependent_task_ids": ["1"],
|
||||
"instruction": "Split the dataset into training and validation sets with a 20% validation size.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "4",
|
||||
"dependent_task_ids": ["3"],
|
||||
"instruction": "Train a model on the training set to predict wine class.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "5",
|
||||
"dependent_task_ids": ["4"],
|
||||
"instruction": "Evaluate the model on the validation set and report the accuracy.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Create a plot to visualize some aspect of the Wine dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
trials_num = 5
|
||||
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
|
||||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_wine" not in result["code"] and "wine_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
assert success_rate >= 0.8
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.ask_review import AskReview
|
||||
from metagpt.actions.di.ask_review import AskReview
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode, truncate
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -8,6 +8,7 @@ async def test_code_running():
|
|||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run("print('hello world!')")
|
||||
assert is_success
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -17,6 +18,7 @@ async def test_split_code_running():
|
|||
_ = await executor.run("z=x+y")
|
||||
output, is_success = await executor.run("assert z==3")
|
||||
assert is_success
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -24,6 +26,7 @@ async def test_execute_error():
|
|||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run("z=1/0")
|
||||
assert not is_success
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
PLOT_CODE = """
|
||||
|
|
@ -52,21 +55,7 @@ async def test_plotting_code():
|
|||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run(PLOT_CODE)
|
||||
assert is_success
|
||||
|
||||
|
||||
def test_truncate():
|
||||
# 代码执行成功
|
||||
output, is_success = truncate("hello world", 5, True)
|
||||
assert "Truncated to show only first 5 characters\nhello" in output
|
||||
assert is_success
|
||||
# 代码执行失败
|
||||
output, is_success = truncate("hello world", 5, False)
|
||||
assert "Truncated to show only last 5 characters\nworld" in output
|
||||
assert not is_success
|
||||
# 异步
|
||||
output, is_success = truncate("<coroutine object", 5, True)
|
||||
assert not is_success
|
||||
assert "await" in output
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -76,6 +65,7 @@ async def test_run_with_timeout():
|
|||
message, success = await executor.run(code)
|
||||
assert not success
|
||||
assert message.startswith("Cell execution timed out")
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -83,7 +73,7 @@ async def test_run_code_text():
|
|||
executor = ExecuteNbCode()
|
||||
message, success = await executor.run(code='print("This is a code!")', language="python")
|
||||
assert success
|
||||
assert message == "This is a code!\n"
|
||||
assert "This is a code!" in message
|
||||
message, success = await executor.run(code="# This is a code!", language="markdown")
|
||||
assert success
|
||||
assert message == "# This is a code!"
|
||||
|
|
@ -91,19 +81,22 @@ async def test_run_code_text():
|
|||
message, success = await executor.run(code=mix_text, language="markdown")
|
||||
assert success
|
||||
assert message == mix_text
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_terminate():
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await executor.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await executor.terminate()
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
assert executor.nb_client.km is None
|
||||
@pytest.mark.parametrize(
|
||||
"k", [(1), (5)]
|
||||
) # k=1 to test a single regular terminate, k>1 to test terminate under continuous run
|
||||
async def test_terminate(k):
|
||||
for _ in range(k):
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await executor.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await executor.terminate()
|
||||
assert executor.nb_client.km is None
|
||||
assert executor.nb_client.kc is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -114,3 +107,22 @@ async def test_reset():
|
|||
assert is_kernel_alive
|
||||
await executor.reset()
|
||||
assert executor.nb_client.km is None
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_outputs():
|
||||
executor = ExecuteNbCode()
|
||||
code = """
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({'ID': [1,2,3], 'NAME': ['a', 'b', 'c']})
|
||||
print(df.columns)
|
||||
print(f"columns num:{len(df.columns)}")
|
||||
print(df['DUMMPY_ID'])
|
||||
"""
|
||||
output, is_success = await executor.run(code)
|
||||
assert not is_success
|
||||
assert "Index(['ID', 'NAME'], dtype='object')" in output
|
||||
assert "KeyError: 'DUMMPY_ID'" in output
|
||||
assert "columns num:2" in output
|
||||
await executor.terminate()
|
||||
79
tests/metagpt/actions/di/test_write_analysis_code.py
Normal file
79
tests/metagpt/actions/di/test_write_analysis_code.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_plan():
|
||||
write_code = WriteAnalysisCode()
|
||||
|
||||
user_requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
plan_status = "\n## Finished Tasks\n### code\n```python\n\n```\n\n### execution result\n\n\n## Current Task\nLoad the sklearn Iris dataset and perform exploratory data analysis\n\n## Task Guidance\nWrite complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.\nSpecifically, \nThe current task is about exploratory data analysis, please note the following:\n- Distinguish column types with `select_dtypes` for tailored analysis and visualization, such as correlation.\n- Remember to `import numpy as np` before using Numpy functions.\n\n"
|
||||
|
||||
code = await write_code.run(user_requirement=user_requirement, plan_status=plan_status)
|
||||
assert len(code) > 0
|
||||
assert "sklearn" in code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code = WriteAnalysisCode()
|
||||
|
||||
user_requirement = "Preprocess sklearn Wine recognition dataset and train a model to predict wine class (20% as validation), and show validation accuracy."
|
||||
tool_info = """
|
||||
## Capabilities
|
||||
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python class or function.
|
||||
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
|
||||
|
||||
## Available Tools:
|
||||
Each tool is described in JSON format. When you call a tool, import the tool from its path first.
|
||||
{'FillMissingValue': {'type': 'class', 'description': 'Completing missing values with simple strategies.', 'methods': {'__init__': {'type': 'function', 'description': 'Initialize self. ', 'signature': '(self, features: \'list\', strategy: "Literal[\'mean\', \'median\', \'most_frequent\', \'constant\']" = \'mean\', fill_value=None)', 'parameters': 'Args: features (list): Columns to be processed. strategy (Literal["mean", "median", "most_frequent", "constant"], optional): The imputation strategy, notice \'mean\' and \'median\' can only be used for numeric features. Defaults to \'mean\'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.'}, 'fit': {'type': 'function', 'description': 'Fit a model to be used in subsequent transform. ', 'signature': "(self, df: 'pd.DataFrame')", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame.'}, 'fit_transform': {'type': 'function', 'description': 'Fit and transform the input DataFrame. ', 'signature': "(self, df: 'pd.DataFrame') -> 'pd.DataFrame'", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.'}, 'transform': {'type': 'function', 'description': 'Transform the input DataFrame with the fitted model. ', 'signature': "(self, df: 'pd.DataFrame') -> 'pd.DataFrame'", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.'}}, 'tool_path': 'metagpt/tools/libs/data_preprocess.py'}
|
||||
"""
|
||||
|
||||
code = await write_code.run(user_requirement=user_requirement, tool_info=tool_info)
|
||||
assert len(code) > 0
|
||||
assert "metagpt.tools.libs" in code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_with_reflection():
|
||||
user_requirement = "read a dataset test.csv and print its head"
|
||||
|
||||
plan_status = """
|
||||
## Finished Tasks
|
||||
### code
|
||||
```python
|
||||
```
|
||||
|
||||
### execution result
|
||||
|
||||
## Current Task
|
||||
import pandas and load the dataset from 'test.csv'.
|
||||
|
||||
## Task Guidance
|
||||
Write complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.
|
||||
Specifically,
|
||||
"""
|
||||
|
||||
wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv
|
||||
error = """
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 2, in <module>
|
||||
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 478, in read_excel
|
||||
io = ExcelFile(io, storage_options=storage_options, engine=engine)
|
||||
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 1500, in __init__
|
||||
raise ValueError(
|
||||
ValueError: Excel file format cannot be determined, you must specify an engine manually.
|
||||
"""
|
||||
working_memory = [
|
||||
Message(content=wrong_code, role="assistant"),
|
||||
Message(content=error, role="user"),
|
||||
]
|
||||
new_code = await WriteAnalysisCode().run(
|
||||
user_requirement=user_requirement,
|
||||
plan_status=plan_status,
|
||||
working_memory=working_memory,
|
||||
use_reflection=True,
|
||||
)
|
||||
assert "read_csv" in new_code # should correct read_excel to read_csv
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.write_plan import (
|
||||
from metagpt.actions.di.write_plan import (
|
||||
Plan,
|
||||
Task,
|
||||
WritePlan,
|
||||
|
|
@ -23,12 +23,10 @@ def test_precheck_update_plan_from_rsp():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("use_tools", [(False), (True)])
|
||||
async def test_write_plan(use_tools):
|
||||
async def test_write_plan():
|
||||
rsp = await WritePlan().run(
|
||||
context=[Message("run analysis on sklearn iris dataset", role="user")], use_tools=use_tools
|
||||
context=[Message("Run data analysis on sklearn Iris dataset, include a plot", role="user")]
|
||||
)
|
||||
|
||||
assert "task_id" in rsp
|
||||
assert "instruction" in rsp
|
||||
assert "json" not in rsp # the output should be the content inside ```json ```
|
||||
|
|
@ -37,7 +37,7 @@ DESIGN = {
|
|||
}
|
||||
|
||||
|
||||
TASKS = {
|
||||
TASK = {
|
||||
"Required Python packages": ["pygame==2.0.1"],
|
||||
"Required Other language third-party packages": ["No third-party dependencies required"],
|
||||
"Logic Analysis": [
|
||||
|
|
|
|||
|
|
@ -9,8 +9,10 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -25,3 +27,16 @@ async def test_design_api(context):
|
|||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refined_design_api(context):
|
||||
await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON))
|
||||
await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE)
|
||||
|
||||
design_api = WriteDesign(context=context, llm=LLM())
|
||||
|
||||
result = await design_api.run(Message(content="", instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
|
|
|||
|
|
@ -9,13 +9,19 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
REFINED_DESIGN_JSON,
|
||||
REFINED_PRD_JSON,
|
||||
TASK_SAMPLE,
|
||||
)
|
||||
from tests.metagpt.actions.mock_json import DESIGN, PRD
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api(context):
|
||||
async def test_task(context):
|
||||
await context.repo.docs.prd.save("1.txt", content=str(PRD))
|
||||
await context.repo.docs.system_design.save("1.txt", content=str(DESIGN))
|
||||
logger.info(context.git_repo)
|
||||
|
|
@ -26,3 +32,19 @@ async def test_design_api(context):
|
|||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refined_task(context):
|
||||
await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON))
|
||||
await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON))
|
||||
await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE)
|
||||
|
||||
logger.info(context.git_repo)
|
||||
|
||||
action = WriteTasks(context=context, llm=LLM())
|
||||
|
||||
result = await action.run(Message(content="", instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
|
|
|||
|
|
@ -10,13 +10,14 @@ from openai._models import BaseModel
|
|||
|
||||
from metagpt.actions.action_node import ActionNode, dict_to_markdown
|
||||
from metagpt.actions.project_management import NEW_REQ_TEMPLATE
|
||||
from metagpt.actions.project_management_an import REFINED_PM_NODE
|
||||
from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE
|
||||
from metagpt.llm import LLM
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
REFINED_DESIGN_JSON,
|
||||
REFINED_TASKS_JSON,
|
||||
TASKS_SAMPLE,
|
||||
REFINED_TASK_JSON,
|
||||
TASK_SAMPLE,
|
||||
)
|
||||
from tests.metagpt.actions.mock_json import TASK
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
@ -24,20 +25,40 @@ def llm():
|
|||
return LLM()
|
||||
|
||||
|
||||
def mock_refined_tasks_json():
|
||||
return REFINED_TASKS_JSON
|
||||
def mock_refined_task_json():
|
||||
return REFINED_TASK_JSON
|
||||
|
||||
|
||||
def mock_task_json():
|
||||
return TASK
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_management_an(mocker):
|
||||
root = ActionNode.from_children(
|
||||
"ProjectManagement", [ActionNode(key="", expected_type=str, instruction="", example="")]
|
||||
)
|
||||
root.instruct_content = BaseModel()
|
||||
root.instruct_content.model_dump = mock_task_json
|
||||
mocker.patch("metagpt.actions.project_management_an.PM_NODE.fill", return_value=root)
|
||||
|
||||
node = await PM_NODE.fill(dict_to_markdown(REFINED_DESIGN_JSON), llm)
|
||||
|
||||
assert "Logic Analysis" in node.instruct_content.model_dump()
|
||||
assert "Task list" in node.instruct_content.model_dump()
|
||||
assert "Shared Knowledge" in node.instruct_content.model_dump()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_management_an_inc(mocker):
|
||||
root = ActionNode.from_children(
|
||||
"RefinedProjectManagement", [ActionNode(key="", expected_type=str, instruction="", example="")]
|
||||
)
|
||||
root.instruct_content = BaseModel()
|
||||
root.instruct_content.model_dump = mock_refined_tasks_json
|
||||
root.instruct_content.model_dump = mock_refined_task_json
|
||||
mocker.patch("metagpt.actions.project_management_an.REFINED_PM_NODE.fill", return_value=root)
|
||||
|
||||
prompt = NEW_REQ_TEMPLATE.format(old_task=TASKS_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON))
|
||||
prompt = NEW_REQ_TEMPLATE.format(old_task=TASK_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON))
|
||||
node = await REFINED_PM_NODE.fill(prompt, llm)
|
||||
|
||||
assert "Refined Logic Analysis" in node.instruct_content.model_dump()
|
||||
|
|
|
|||
|
|
@ -23,6 +23,8 @@ async def test_rebuild(context):
|
|||
context=context,
|
||||
)
|
||||
await action.run()
|
||||
rows = await action.graph_db.select()
|
||||
assert rows
|
||||
assert context.repo.docs.graph_repo.changed_files
|
||||
|
||||
|
||||
|
|
@ -45,6 +47,12 @@ def test_align_path(path, direction, diff, want):
|
|||
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT/metagpt", "=", "."),
|
||||
("/Users/x/github/MetaGPT", "/Users/x/github/MetaGPT/metagpt", "-", "metagpt"),
|
||||
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT", "+", "metagpt"),
|
||||
(
|
||||
"/Users/x/github/MetaGPT-env/lib/python3.9/site-packages/moviepy",
|
||||
"/Users/x/github/MetaGPT-env/lib/python3.9/site-packages/",
|
||||
"+",
|
||||
"moviepy",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_diff_path(path_root, package_root, want_direction, want_diff):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
@Time : 2024/1/4
|
||||
@Author : mashenquan
|
||||
@File : test_rebuild_sequence_view.py
|
||||
@Desc : Unit tests for reconstructing the sequence diagram from a source code project.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -14,25 +15,40 @@ from metagpt.const import GRAPH_REPO_FILE_REPO
|
|||
from metagpt.llm import LLM
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.git_repository import ChangeType
|
||||
from metagpt.utils.graph_repository import SPO
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip
|
||||
async def test_rebuild(context):
|
||||
async def test_rebuild(context, mocker):
|
||||
# Mock
|
||||
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
|
||||
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.class_view.json")
|
||||
graph_db_filename = Path(context.repo.workdir.name).with_suffix(".json")
|
||||
await context.repo.docs.graph_repo.save(filename=str(graph_db_filename), content=data)
|
||||
context.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
|
||||
context.git_repo.commit("commit1")
|
||||
# mock_spo = SPO(
|
||||
# subject="metagpt/startup.py:__name__:__main__",
|
||||
# predicate="has_page_info",
|
||||
# object_='{"lineno":78,"end_lineno":79,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
|
||||
# )
|
||||
mock_spo = SPO(
|
||||
subject="metagpt/management/skill_manager.py:__name__:__main__",
|
||||
predicate="has_page_info",
|
||||
object_='{"lineno":113,"end_lineno":116,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
|
||||
)
|
||||
mocker.patch.object(RebuildSequenceView, "_search_main_entry", return_value=[mock_spo])
|
||||
|
||||
action = RebuildSequenceView(
|
||||
name="RedBean",
|
||||
i_context=str(Path(__file__).parent.parent.parent.parent / "metagpt"),
|
||||
i_context=str(
|
||||
Path(__file__).parent.parent.parent.parent / "metagpt/management/skill_manager.py:__name__:__main__"
|
||||
),
|
||||
llm=LLM(),
|
||||
context=context,
|
||||
)
|
||||
await action.run()
|
||||
rows = await action.graph_db.select()
|
||||
assert rows
|
||||
assert context.repo.docs.graph_repo.changed_files
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,13 @@
|
|||
@File : test_summarize_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. Unit test for summarize_code.py
|
||||
"""
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from tests.mock.mock_llm import MockLLM
|
||||
|
||||
DESIGN_CONTENT = """
|
||||
{"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are already clear."}
|
||||
|
|
@ -175,12 +174,87 @@ class Snake:
|
|||
|
||||
"""
|
||||
|
||||
mock_rsp = """
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Game{
|
||||
+int score
|
||||
+int level
|
||||
+Snake snake
|
||||
+Food food
|
||||
+start_game() void
|
||||
+initialize_game() void
|
||||
+game_loop() void
|
||||
+update() void
|
||||
+draw() void
|
||||
+handle_events() void
|
||||
+check_collision() void
|
||||
+increase_score() void
|
||||
+increase_level() void
|
||||
+game_over() void
|
||||
Game()
|
||||
}
|
||||
class Snake{
|
||||
+list body
|
||||
+tuple direction
|
||||
+move() void
|
||||
+change_direction(direction: str) void
|
||||
+grow() void
|
||||
+get_head() tuple
|
||||
+get_body() list
|
||||
Snake()
|
||||
}
|
||||
class Food{
|
||||
+tuple position
|
||||
+generate() void
|
||||
+get_position() tuple
|
||||
Food()
|
||||
}
|
||||
Game "1" -- "1" Snake: has
|
||||
Game "1" -- "1" Food: has
|
||||
```
|
||||
|
||||
```sequenceDiagram
|
||||
participant M as Main
|
||||
participant G as Game
|
||||
participant S as Snake
|
||||
participant F as Food
|
||||
M->>G: start_game()
|
||||
G->>G: initialize_game()
|
||||
G->>G: game_loop()
|
||||
G->>S: move()
|
||||
G->>S: change_direction()
|
||||
G->>S: grow()
|
||||
G->>F: generate()
|
||||
S->>S: move()
|
||||
S->>S: change_direction()
|
||||
S->>S: grow()
|
||||
F->>F: generate()
|
||||
```
|
||||
|
||||
## Summary
|
||||
The code consists of the main game logic, including the Game, Snake, and Food classes. The game loop is responsible for updating and drawing the game elements, handling events, checking collisions, and managing the game state. The Snake class handles the movement, growth, and direction changes of the snake, while the Food class is responsible for generating and tracking the position of food items.
|
||||
|
||||
## TODOs
|
||||
- Modify 'game.py' to add the implementation of obstacle handling and interaction with the game loop.
|
||||
- Implement 'obstacle.py' to include the methods for spawning, moving, and disappearing of obstacles, as well as collision detection with the snake.
|
||||
- Update 'main.py' to initialize the obstacle and incorporate it into the game loop.
|
||||
- Update the mermaid call flow diagram to include the interaction with the obstacle.
|
||||
|
||||
```python
|
||||
{
|
||||
"files_to_modify": {
|
||||
"game.py": "Add obstacle handling and interaction with the game loop",
|
||||
"obstacle.py": "Implement obstacle class with necessary methods",
|
||||
"main.py": "Initialize the obstacle and incorporate it into the game loop"
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_code(context):
|
||||
git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"
|
||||
git_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def test_summarize_code(context, mocker):
|
||||
context.src_workspace = context.git_repo.workdir / "src"
|
||||
await context.repo.docs.system_design.save(filename="1.json", content=DESIGN_CONTENT)
|
||||
await context.repo.docs.task.save(filename="1.json", content=TASK_CONTENT)
|
||||
|
|
@ -189,6 +263,7 @@ async def test_summarize_code(context):
|
|||
await context.repo.srcs.save(filename="game.py", content=GAME_PY)
|
||||
await context.repo.srcs.save(filename="main.py", content=MAIN_PY)
|
||||
await context.repo.srcs.save(filename="snake.py", content=SNAKE_PY)
|
||||
mocker.patch.object(MockLLM, "_mock_rsp", return_value=mock_rsp)
|
||||
|
||||
all_files = context.repo.srcs.all_files
|
||||
summarization_context = CodeSummarizeContext(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
@File : test_write_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
|
@ -14,10 +14,27 @@ import pytest
|
|||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.common import CodeParser, aread
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
CODE_PLAN_AND_CHANGE_SAMPLE,
|
||||
REFINED_CODE_INPUT_SAMPLE,
|
||||
REFINED_DESIGN_JSON,
|
||||
REFINED_TASK_JSON,
|
||||
)
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
||||
def setup_inc_workdir(context, inc: bool = False):
|
||||
"""setup incremental workdir for testing"""
|
||||
context.src_workspace = context.git_repo.workdir / "src"
|
||||
if inc:
|
||||
context.config.inc = inc
|
||||
context.repo.old_workspace = context.repo.git_repo.workdir / "old"
|
||||
context.config.project_path = "old"
|
||||
|
||||
return context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code(context):
|
||||
# Prerequisites
|
||||
|
|
@ -81,5 +98,66 @@ async def test_write_code_deps(context):
|
|||
assert rsp.code_doc.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_refined_code(context, git_dir):
|
||||
# Prerequisites
|
||||
context = setup_inc_workdir(context, inc=True)
|
||||
await context.repo.docs.system_design.save(filename="1.json", content=json.dumps(REFINED_DESIGN_JSON))
|
||||
await context.repo.docs.task.save(filename="1.json", content=json.dumps(REFINED_TASK_JSON))
|
||||
await context.repo.docs.code_plan_and_change.save(
|
||||
filename="1.json", content=json.dumps(CODE_PLAN_AND_CHANGE_SAMPLE)
|
||||
)
|
||||
|
||||
# old_workspace contains the legacy code
|
||||
await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
|
||||
filename="game.py", content=CodeParser.parse_code(block="", text=REFINED_CODE_INPUT_SAMPLE)
|
||||
)
|
||||
|
||||
ccontext = CodingContext(
|
||||
filename="game.py",
|
||||
design_doc=await context.repo.docs.system_design.get(filename="1.json"),
|
||||
task_doc=await context.repo.docs.task.get(filename="1.json"),
|
||||
code_plan_and_change_doc=await context.repo.docs.code_plan_and_change.get(filename="1.json"),
|
||||
code_doc=Document(filename="game.py", content="", root_path="src"),
|
||||
)
|
||||
coding_doc = Document(root_path="src", filename="game.py", content=ccontext.json())
|
||||
|
||||
action = WriteCode(i_context=coding_doc, context=context)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert rsp.code_doc.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_codes(context):
|
||||
# Prerequisites
|
||||
context = setup_inc_workdir(context, inc=True)
|
||||
for filename in ["game.py", "ui.py"]:
|
||||
await context.repo.with_src_path(context.src_workspace).srcs.save(
|
||||
filename=filename, content=f"# {filename}\nnew code ..."
|
||||
)
|
||||
await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
|
||||
filename=filename, content=f"# {filename}\nlegacy code ..."
|
||||
)
|
||||
|
||||
await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
|
||||
filename="gui.py", content="# gui.py\nlegacy code ..."
|
||||
)
|
||||
await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
|
||||
filename="main.py", content='# main.py\nif __name__ == "__main__":\n main()'
|
||||
)
|
||||
task_doc = Document(filename="1.json", content=json.dumps(REFINED_TASK_JSON))
|
||||
|
||||
context.repo = context.repo.with_src_path(context.src_workspace)
|
||||
# Ready to write gui.py
|
||||
codes = await WriteCode.get_codes(task_doc=task_doc, exclude="gui.py", project_repo=context.repo)
|
||||
codes_inc = await WriteCode.get_codes(task_doc=task_doc, exclude="gui.py", project_repo=context.repo, use_inc=True)
|
||||
|
||||
logger.info(codes)
|
||||
logger.info(codes_inc)
|
||||
assert codes
|
||||
assert codes_inc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
@Author : mannaandpoem
|
||||
@File : test_write_code_plan_and_change_an.py
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from openai._models import BaseModel
|
||||
|
||||
|
|
@ -14,15 +16,21 @@ from metagpt.actions.write_code_plan_and_change_an import (
|
|||
REFINED_TEMPLATE,
|
||||
WriteCodePlanAndChange,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodePlanAndChangeContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
CODE_PLAN_AND_CHANGE_SAMPLE,
|
||||
DESIGN_SAMPLE,
|
||||
NEW_REQUIREMENT_SAMPLE,
|
||||
REFINED_CODE_INPUT_SAMPLE,
|
||||
REFINED_CODE_SAMPLE,
|
||||
TASKS_SAMPLE,
|
||||
REFINED_DESIGN_JSON,
|
||||
REFINED_PRD_JSON,
|
||||
REFINED_TASK_JSON,
|
||||
TASK_SAMPLE,
|
||||
)
|
||||
from tests.metagpt.actions.test_write_code import setup_inc_workdir
|
||||
|
||||
|
||||
def mock_code_plan_and_change():
|
||||
|
|
@ -30,27 +38,35 @@ def mock_code_plan_and_change():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_plan_and_change_an(mocker):
|
||||
async def test_write_code_plan_and_change_an(mocker, context, git_dir):
|
||||
context = setup_inc_workdir(context, inc=True)
|
||||
await context.repo.docs.prd.save(filename="2.json", content=json.dumps(REFINED_PRD_JSON))
|
||||
await context.repo.docs.system_design.save(filename="2.json", content=json.dumps(REFINED_DESIGN_JSON))
|
||||
await context.repo.docs.task.save(filename="2.json", content=json.dumps(REFINED_TASK_JSON))
|
||||
|
||||
await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
|
||||
filename="game.py", content=CodeParser.parse_code(block="", text=REFINED_CODE_INPUT_SAMPLE)
|
||||
)
|
||||
|
||||
root = ActionNode.from_children(
|
||||
"WriteCodePlanAndChange", [ActionNode(key="", expected_type=str, instruction="", example="")]
|
||||
)
|
||||
root.instruct_content = BaseModel()
|
||||
root.instruct_content.model_dump = mock_code_plan_and_change
|
||||
mocker.patch("metagpt.actions.write_code_plan_and_change_an.WriteCodePlanAndChange.run", return_value=root)
|
||||
|
||||
requirement = "New requirement"
|
||||
prd_filename = "prd.md"
|
||||
design_filename = "design.md"
|
||||
task_filename = "task.md"
|
||||
code_plan_and_change_context = CodePlanAndChangeContext(
|
||||
requirement=requirement,
|
||||
prd_filename=prd_filename,
|
||||
design_filename=design_filename,
|
||||
task_filename=task_filename,
|
||||
mocker.patch(
|
||||
"metagpt.actions.write_code_plan_and_change_an.WRITE_CODE_PLAN_AND_CHANGE_NODE.fill", return_value=root
|
||||
)
|
||||
node = await WriteCodePlanAndChange(i_context=code_plan_and_change_context).run()
|
||||
|
||||
assert "Code Plan And Change" in node.instruct_content.model_dump()
|
||||
code_plan_and_change_context = CodePlanAndChangeContext(
|
||||
requirement="New requirement",
|
||||
prd_filename="2.json",
|
||||
design_filename="2.json",
|
||||
task_filename="2.json",
|
||||
)
|
||||
node = await WriteCodePlanAndChange(i_context=code_plan_and_change_context, context=context).run()
|
||||
|
||||
assert "Development Plan" in node.instruct_content.model_dump()
|
||||
assert "Incremental Change" in node.instruct_content.model_dump()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -60,7 +76,7 @@ async def test_refine_code(mocker):
|
|||
user_requirement=NEW_REQUIREMENT_SAMPLE,
|
||||
code_plan_and_change=CODE_PLAN_AND_CHANGE_SAMPLE,
|
||||
design=DESIGN_SAMPLE,
|
||||
task=TASKS_SAMPLE,
|
||||
task=TASK_SAMPLE,
|
||||
code=REFINED_CODE_INPUT_SAMPLE,
|
||||
logs="",
|
||||
feedback="",
|
||||
|
|
@ -69,3 +85,25 @@ async def test_refine_code(mocker):
|
|||
)
|
||||
code = await WriteCode().write_code(prompt=prompt)
|
||||
assert "def" in code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_old_code(context, git_dir):
|
||||
context = setup_inc_workdir(context, inc=True)
|
||||
await context.repo.with_src_path(context.repo.old_workspace).srcs.save(
|
||||
filename="game.py", content=REFINED_CODE_INPUT_SAMPLE
|
||||
)
|
||||
|
||||
code_plan_and_change_context = CodePlanAndChangeContext(
|
||||
requirement="New requirement",
|
||||
prd_filename="1.json",
|
||||
design_filename="1.json",
|
||||
task_filename="1.json",
|
||||
)
|
||||
action = WriteCodePlanAndChange(context=context, i_context=code_plan_and_change_context)
|
||||
|
||||
old_codes = await action.get_old_codes()
|
||||
logger.info(old_codes)
|
||||
|
||||
assert "def" in old_codes
|
||||
assert "class" in old_codes
|
||||
|
|
|
|||
|
|
@ -32,5 +32,28 @@ def add(a, b):
|
|||
print(f"输出内容: {captured.out}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_review_inc(capfd, context):
|
||||
context.src_workspace = context.repo.workdir / "srcs"
|
||||
context.config.inc = True
|
||||
code = """
|
||||
def add(a, b):
|
||||
return a +
|
||||
"""
|
||||
code_plan_and_change = """
|
||||
def add(a, b):
|
||||
- return a +
|
||||
+ return a + b
|
||||
"""
|
||||
coding_context = CodingContext(
|
||||
filename="math.py",
|
||||
design_doc=Document(content="编写一个从a加b的函数,返回a+b"),
|
||||
code_doc=Document(content=code),
|
||||
code_plan_and_change_doc=Document(content=code_plan_and_change),
|
||||
)
|
||||
|
||||
await WriteCodeReview(i_context=coding_context, context=context).run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@File : test_write_prd.py
|
||||
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
|
|
@ -15,6 +16,8 @@ from metagpt.roles.product_manager import ProductManager
|
|||
from metagpt.roles.role import RoleReactMode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE
|
||||
from tests.metagpt.actions.test_write_code import setup_inc_workdir
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -34,5 +37,41 @@ async def test_write_prd(new_filename, context):
|
|||
assert product_manager.context.repo.docs.prd.changed_files
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd_inc(new_filename, context, git_dir):
|
||||
context = setup_inc_workdir(context, inc=True)
|
||||
await context.repo.docs.prd.save("1.txt", PRD_SAMPLE)
|
||||
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)
|
||||
|
||||
action = WritePRD(context=context)
|
||||
prd = await action.run(Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None))
|
||||
logger.info(NEW_REQUIREMENT_SAMPLE)
|
||||
logger.info(prd)
|
||||
|
||||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd.content != ""
|
||||
assert "Refined Requirements" in prd.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_debug(new_filename, context, git_dir):
|
||||
context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name
|
||||
|
||||
await context.repo.with_src_path(context.src_workspace).srcs.save(
|
||||
filename="main.py", content='if __name__ == "__main__":\nmain()'
|
||||
)
|
||||
requirements = "Please fix the bug in the code."
|
||||
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
|
||||
action = WritePRD(context=context)
|
||||
|
||||
prd = await action.run(Message(content=requirements, instruct_content=None))
|
||||
logger.info(prd)
|
||||
|
||||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd.content != ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from metagpt.document_store.chromadb_store import ChromaStore
|
|||
def test_chroma_store():
|
||||
"""FIXME:chroma使用感觉很诡异,一用Python就挂,测试用例里也是"""
|
||||
# 创建 ChromaStore 实例,使用 'sample_collection' 集合
|
||||
document_store = ChromaStore("sample_collection_1")
|
||||
document_store = ChromaStore("sample_collection_1", get_or_create=True)
|
||||
|
||||
# 使用 write 方法添加多个文档
|
||||
document_store.write(
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@File : test_faiss_store.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
|
|
@ -14,9 +15,24 @@ from metagpt.logs import logger
|
|||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
|
||||
num = len(texts)
|
||||
embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim
|
||||
embeds = (embeds - embeds.mean(axis=0)) / embeds.std(axis=0)
|
||||
return embeds.tolist()
|
||||
|
||||
|
||||
def mock_openai_embed_document(self, text: str) -> list[float]:
|
||||
embeds = mock_openai_embed_documents(self, [text])
|
||||
return embeds[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_json():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
async def test_search_json(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -24,8 +40,11 @@ async def test_search_json():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_xlsx():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
async def test_search_xlsx(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -33,8 +52,11 @@ async def test_search_xlsx():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
async def test_write(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.docstore
|
||||
assert _faiss_store.index
|
||||
assert _faiss_store.storage_context.docstore
|
||||
assert _faiss_store.storage_context.vector_store.client
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of MincraftExtEnv
|
||||
|
||||
|
||||
from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
|
||||
|
||||
|
||||
def test_mincraft_ext_env():
|
||||
ext_env = MincraftExtEnv()
|
||||
assert ext_env.server, f"{ext_env.server_host}:{ext_env.server_port}"
|
||||
assert MC_CKPT_DIR.joinpath("skill/code").exists()
|
||||
assert ext_env.warm_up.get("optional_inventory_items") == 7
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of MinecraftExtEnv
|
||||
|
||||
|
||||
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
|
||||
|
||||
|
||||
def test_minecraft_ext_env():
|
||||
ext_env = MinecraftExtEnv()
|
||||
assert ext_env.server, f"{ext_env.server_host}:{ext_env.server_port}"
|
||||
assert MC_CKPT_DIR.joinpath("skill/code").exists()
|
||||
assert ext_env.warm_up.get("optional_inventory_items") == 7
|
||||
42
tests/metagpt/memory/mock_text_embed.py
Normal file
42
tests/metagpt/memory/mock_text_embed.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import numpy as np
|
||||
|
||||
dim = 1536 # openai embedding dim
|
||||
embed_zeros_arrr = np.zeros(shape=[1, dim]).tolist()
|
||||
embed_ones_arrr = np.ones(shape=[1, dim]).tolist()
|
||||
|
||||
text_embed_arr = [
|
||||
{"text": "Write a cli snake game", "embed": embed_zeros_arrr}, # mock data, same as below
|
||||
{"text": "Write a game of cli snake", "embed": embed_zeros_arrr},
|
||||
{"text": "Write a 2048 web game", "embed": embed_ones_arrr},
|
||||
{"text": "Write a Battle City", "embed": embed_ones_arrr},
|
||||
{
|
||||
"text": "The user has requested the creation of a command-line interface (CLI) snake game",
|
||||
"embed": embed_zeros_arrr,
|
||||
},
|
||||
{"text": "The request is command-line interface (CLI) snake game", "embed": embed_zeros_arrr},
|
||||
{
|
||||
"text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
|
||||
"embed": embed_ones_arrr,
|
||||
},
|
||||
]
|
||||
|
||||
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
|
||||
idx = text_idx_dict.get(texts[0])
|
||||
embed = text_embed_arr[idx].get("embed")
|
||||
return embed
|
||||
|
||||
|
||||
def mock_openai_embed_document(self, text: str) -> list[float]:
|
||||
embeds = mock_openai_embed_documents(self, [text])
|
||||
return embeds[0]
|
||||
|
||||
|
||||
async def mock_openai_aembed_document(self, text: str) -> list[float]:
|
||||
return mock_openai_embed_document(self, text)
|
||||
|
|
@ -4,20 +4,29 @@
|
|||
@Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config2 import config
|
||||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_aembed_document,
|
||||
mock_openai_embed_document,
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_ltm_search():
|
||||
@pytest.mark.asyncio
|
||||
async def test_ltm_search(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
mocker.patch(
|
||||
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
|
||||
)
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
from metagpt.environment import Environment
|
||||
|
||||
|
|
@ -27,41 +36,26 @@ def test_ltm_search():
|
|||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
idea = "Write a cli snake game"
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([message])
|
||||
news = await ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
news = await ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([new_message])
|
||||
news = await ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
ltm.add(new_message)
|
||||
|
||||
# restore from local index
|
||||
ltm_new = LongTermMemory()
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.find_news([message])
|
||||
assert len(news) == 0
|
||||
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = "Write a Battle City"
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
||||
ltm_new.clear()
|
||||
ltm.clear()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -4,56 +4,75 @@
|
|||
@Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_aembed_document,
|
||||
mock_openai_embed_document,
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
idea = "Write a cli snake game"
|
||||
@pytest.mark.asyncio
|
||||
async def test_idea_message(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
mocker.patch(
|
||||
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
|
||||
)
|
||||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
memory_storage.recover_memory(role_id)
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
new_messages = await memory_storage.search_similar(sim_message)
|
||||
assert len(new_messages) == 1 # similar, return []
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
new_messages = await memory_storage.search_similar(new_message)
|
||||
assert len(new_messages) == 0
|
||||
|
||||
memory_storage.clean()
|
||||
assert memory_storage.is_initialized is False
|
||||
|
||||
|
||||
def test_actionout_message():
|
||||
@pytest.mark.asyncio
|
||||
async def test_actionout_message(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
mocker.patch(
|
||||
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
|
||||
)
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
role_id = "UTUser2(Architect)"
|
||||
content = "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
content = text_embed_arr[4].get(
|
||||
"text", "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
)
|
||||
message = Message(
|
||||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
|
@ -61,21 +80,22 @@ def test_actionout_message():
|
|||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
memory_storage.recover_memory(role_id)
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = "The request is command-line interface (CLI) snake game"
|
||||
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
new_messages = await memory_storage.search_similar(sim_message)
|
||||
assert len(new_messages) == 1 # similar, return []
|
||||
|
||||
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
new_conent = text_embed_arr[6].get(
|
||||
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
)
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
new_messages = await memory_storage.search_similar(new_message)
|
||||
assert len(new_messages) == 0
|
||||
|
||||
memory_storage.clean()
|
||||
assert memory_storage.is_initialized is False
|
||||
|
|
|
|||
|
|
@ -42,3 +42,21 @@ mock_llm_config_zhipu = LLMConfig(
|
|||
model="mock_zhipu_model",
|
||||
proxy="http://localhost:8080",
|
||||
)
|
||||
|
||||
|
||||
mock_llm_config_spark = LLMConfig(
|
||||
api_type="spark",
|
||||
app_id="xxx",
|
||||
api_key="xxx",
|
||||
api_secret="xxx",
|
||||
domain="generalv2",
|
||||
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
|
||||
)
|
||||
|
||||
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
|
||||
|
||||
mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model="qwen-max")
|
||||
|
||||
mock_llm_config_anthropic = LLMConfig(
|
||||
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
|
||||
)
|
||||
|
|
|
|||
185
tests/metagpt/provider/req_resp_const.py
Normal file
185
tests/metagpt/provider/req_resp_const.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : default request & response data for provider unittest
|
||||
|
||||
|
||||
from anthropic.types import (
|
||||
ContentBlock,
|
||||
ContentBlockDeltaEvent,
|
||||
Message,
|
||||
MessageStartEvent,
|
||||
TextDelta,
|
||||
)
|
||||
from anthropic.types import Usage as AnthropicUsage
|
||||
from dashscope.api_entities.dashscope_response import (
|
||||
DashScopeAPIResponse,
|
||||
GenerationOutput,
|
||||
GenerationResponse,
|
||||
GenerationUsage,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from qianfan.resources.typing import QfResponse
|
||||
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
prompt = "who are you?"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
resp_cont_tmpl = "I'm {name}"
|
||||
default_resp_cont = resp_cont_tmpl.format(name="GPT")
|
||||
|
||||
|
||||
# part of whole ChatCompletion of openai like structure
|
||||
def get_part_chat_completion(name: str) -> dict:
|
||||
part_chat_completion = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": resp_cont_tmpl.format(name=name),
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
return part_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion(name: str) -> ChatCompletion:
|
||||
openai_chat_completion = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="xx/xxx",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
return openai_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
|
||||
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
|
||||
usage = usage if not usage_as_dict else usage.model_dump()
|
||||
openai_chat_completion_chunk = ChatCompletionChunk(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="xx/xxx",
|
||||
object="chat.completion.chunk",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
return openai_chat_completion_chunk
|
||||
|
||||
|
||||
# For gemini
|
||||
gemini_messages = [{"role": "user", "parts": prompt}]
|
||||
|
||||
|
||||
# For QianFan
|
||||
qf_jsonbody_dict = {
|
||||
"id": "as-4v1h587fyv",
|
||||
"object": "chat.completion",
|
||||
"created": 1695021339,
|
||||
"result": "",
|
||||
"is_truncated": False,
|
||||
"need_clear_history": False,
|
||||
"usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22},
|
||||
}
|
||||
|
||||
|
||||
def get_qianfan_response(name: str) -> QfResponse:
|
||||
qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name)
|
||||
return QfResponse(code=200, body=qf_jsonbody_dict)
|
||||
|
||||
|
||||
# For DashScope
|
||||
def get_dashscope_response(name: str) -> GenerationResponse:
|
||||
return GenerationResponse.from_api_response(
|
||||
DashScopeAPIResponse(
|
||||
status_code=200,
|
||||
output=GenerationOutput(
|
||||
**{
|
||||
"text": "",
|
||||
"finish_reason": "",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"message": {"role": "assistant", "content": resp_cont_tmpl.format(name=name)},
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# For Anthropic
|
||||
def get_anthropic_response(name: str, stream: bool = False) -> Message:
|
||||
if stream:
|
||||
return [
|
||||
MessageStartEvent(
|
||||
message=Message(
|
||||
id="xxx",
|
||||
model=name,
|
||||
role="assistant",
|
||||
type="message",
|
||||
content=[ContentBlock(text="", type="text")],
|
||||
usage=AnthropicUsage(input_tokens=10, output_tokens=10),
|
||||
),
|
||||
type="message_start",
|
||||
),
|
||||
ContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"),
|
||||
type="content_block_delta",
|
||||
),
|
||||
]
|
||||
else:
|
||||
return Message(
|
||||
id="xxx",
|
||||
model=name,
|
||||
role="assistant",
|
||||
type="message",
|
||||
content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")],
|
||||
usage=AnthropicUsage(input_tokens=10, output_tokens=10),
|
||||
)
|
||||
|
||||
|
||||
# For llm general chat functions call
|
||||
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
|
||||
resp = await llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
|
@ -2,31 +2,45 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Claude2
|
||||
|
||||
|
||||
import pytest
|
||||
from anthropic.resources.completions import Completion
|
||||
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_anthropic
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_anthropic_response,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
prompt = "who are you"
|
||||
resp = "I'am Claude2"
|
||||
name = "claude-3-opus-20240229"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
async def mock_anthropic_messages_create(
|
||||
self, messages: list[dict], model: str, stream: bool = True, max_tokens: int = None, system: str = None
|
||||
) -> Completion:
|
||||
if stream:
|
||||
|
||||
async def aresp_iterator():
|
||||
resps = get_anthropic_response(name, stream=True)
|
||||
for resp in resps:
|
||||
yield resp
|
||||
|
||||
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
|
||||
assert resp == Claude2(mock_llm_config).ask(prompt)
|
||||
return aresp_iterator()
|
||||
else:
|
||||
return get_anthropic_response(name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
|
||||
assert resp == await Claude2(mock_llm_config).aask(prompt)
|
||||
async def test_anthropic_acompletion(mocker):
|
||||
mocker.patch("anthropic.resources.messages.AsyncMessages.create", mock_anthropic_messages_create)
|
||||
|
||||
anthropic_llm = AnthropicLLM(mock_llm_config_anthropic)
|
||||
|
||||
resp = await anthropic_llm.acompletion(messages)
|
||||
assert resp.content[0].text == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(anthropic_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -11,21 +11,13 @@ import pytest
|
|||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
default_resp_cont,
|
||||
get_part_chat_completion,
|
||||
prompt,
|
||||
)
|
||||
|
||||
default_chat_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'am GPT",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
prompt_msg = "who are you"
|
||||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
name = "GPT"
|
||||
|
||||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
|
|
@ -33,16 +25,19 @@ class MockBaseLLM(BaseLLM):
|
|||
pass
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
return get_part_chat_completion(name)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
pass
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
return get_part_chat_completion(name)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
pass
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
return resp_content
|
||||
|
||||
async def close(self):
|
||||
return default_chat_resp
|
||||
return default_resp_cont
|
||||
|
||||
|
||||
def test_base_llm():
|
||||
|
|
@ -86,25 +81,25 @@ def test_base_llm():
|
|||
choice_text = base_llm.get_choice_text(openai_funccall_resp)
|
||||
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
|
||||
|
||||
# resp = base_llm.ask(prompt_msg)
|
||||
# assert resp == resp_content
|
||||
# resp = base_llm.ask(prompt)
|
||||
# assert resp == default_resp_cont
|
||||
|
||||
# resp = base_llm.ask_batch([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
# resp = base_llm.ask_batch([prompt])
|
||||
# assert resp == default_resp_cont
|
||||
|
||||
# resp = base_llm.ask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
# resp = base_llm.ask_code([prompt])
|
||||
# assert resp == default_resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_base_llm():
|
||||
base_llm = MockBaseLLM()
|
||||
|
||||
resp = await base_llm.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
resp = await base_llm.aask(prompt)
|
||||
assert resp == default_resp_cont
|
||||
|
||||
resp = await base_llm.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
resp = await base_llm.aask_batch([prompt])
|
||||
assert resp == default_resp_cont
|
||||
|
||||
# resp = await base_llm.aask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
# resp = await base_llm.aask_code([prompt])
|
||||
# assert resp == default_resp_cont
|
||||
|
|
|
|||
73
tests/metagpt/provider/test_dashscope_api.py
Normal file
73
tests/metagpt/provider/test_dashscope_api.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of DashScopeLLM
|
||||
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
import pytest
|
||||
from dashscope.api_entities.dashscope_response import GenerationResponse
|
||||
|
||||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_dashscope
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_dashscope_response,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "qwen-max"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
@classmethod
|
||||
def mock_dashscope_call(
|
||||
cls,
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
api_key: str,
|
||||
result_format: str,
|
||||
incremental_output: bool = True,
|
||||
stream: bool = False,
|
||||
) -> GenerationResponse:
|
||||
return get_dashscope_response(name)
|
||||
|
||||
|
||||
@classmethod
|
||||
async def mock_dashscope_acall(
|
||||
cls,
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
api_key: str,
|
||||
result_format: str,
|
||||
incremental_output: bool = True,
|
||||
stream: bool = False,
|
||||
) -> Union[AsyncGenerator[GenerationResponse, None], GenerationResponse]:
|
||||
resps = [get_dashscope_response(name)]
|
||||
|
||||
if stream:
|
||||
|
||||
async def aresp_iterator(resps: list[GenerationResponse]):
|
||||
for resp in resps:
|
||||
yield resp
|
||||
|
||||
return aresp_iterator(resps)
|
||||
else:
|
||||
return resps[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashscope_acompletion(mocker):
|
||||
mocker.patch("dashscope.aigc.generation.Generation.call", mock_dashscope_call)
|
||||
mocker.patch("metagpt.provider.dashscope_api.AGeneration.acall", mock_dashscope_acall)
|
||||
|
||||
dashscope_llm = DashScopeLLM(mock_llm_config_dashscope)
|
||||
|
||||
resp = dashscope_llm.completion(messages)
|
||||
assert resp.choices[0]["message"]["content"] == resp_cont
|
||||
|
||||
resp = await dashscope_llm.acompletion(messages)
|
||||
assert resp.choices[0]["message"]["content"] == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(dashscope_llm, prompt, messages, resp_cont)
|
||||
|
|
@ -1,114 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of fireworks api
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireworksLLM,
|
||||
)
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="accounts/fireworks/models/llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=dict(default_resp.usage),
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
def test_fireworks_costmanager():
|
||||
cost_manager = FireworksCostManager()
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
|
||||
|
||||
cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat")
|
||||
assert cost_manager.total_cost == 0.5
|
||||
|
||||
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield default_resp_chunk
|
||||
|
||||
return Iterator()
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
fireworks_gpt = FireworksLLM(mock_llm_config)
|
||||
fireworks_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
fireworks_gpt._update_costs(
|
||||
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
|
||||
)
|
||||
assert fireworks_gpt.get_costs() == Costs(
|
||||
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
|
||||
)
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
@ -11,6 +11,12 @@ from google.generativeai.types import content_types
|
|||
|
||||
from metagpt.provider.google_gemini_api import GeminiLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
gemini_messages,
|
||||
llm_general_chat_funcs_test,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -18,10 +24,8 @@ class MockGeminiResponse(ABC):
|
|||
text: str
|
||||
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "parts": prompt_msg}]
|
||||
resp_content = "I'm gemini from google"
|
||||
default_resp = MockGeminiResponse(text=resp_content)
|
||||
resp_cont = resp_cont_tmpl.format(name="gemini")
|
||||
default_resp = MockGeminiResponse(text=resp_cont)
|
||||
|
||||
|
||||
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
|
||||
|
|
@ -60,28 +64,18 @@ async def test_gemini_acompletion(mocker):
|
|||
mock_gemini_generate_content_async,
|
||||
)
|
||||
|
||||
gemini_gpt = GeminiLLM(mock_llm_config)
|
||||
gemini_llm = GeminiLLM(mock_llm_config)
|
||||
|
||||
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
|
||||
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
|
||||
assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
|
||||
assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
|
||||
|
||||
usage = gemini_gpt.get_usage(messages, resp_content)
|
||||
usage = gemini_llm.get_usage(gemini_messages, resp_cont)
|
||||
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
|
||||
|
||||
resp = gemini_gpt.completion(messages)
|
||||
resp = gemini_llm.completion(gemini_messages)
|
||||
assert resp == default_resp
|
||||
|
||||
resp = await gemini_gpt.acompletion(messages)
|
||||
resp = await gemini_llm.acompletion(gemini_messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,15 @@ import pytest
|
|||
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm ollama"
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_content}}
|
||||
resp_cont = resp_cont_tmpl.format(name="ollama")
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_cont}}
|
||||
|
||||
|
||||
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
|
||||
|
|
@ -41,19 +44,12 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
|
|||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
|
||||
|
||||
ollama_gpt = OllamaLLM(mock_llm_config)
|
||||
ollama_llm = OllamaLLM(mock_llm_config)
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
resp = await ollama_llm.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
resp = await ollama_llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -1,92 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.open_llm_api import OpenLLM
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
resp_content = "I'm llama2"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703302755,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield default_resp_chunk
|
||||
|
||||
return Iterator()
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openllm_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
openllm_gpt = OpenLLM(mock_llm_config)
|
||||
openllm_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
assert openllm_gpt.get_costs() == Costs(
|
||||
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
|
||||
)
|
||||
|
||||
resp = await openllm_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
resp = await openllm_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion import Choice, CompletionUsage
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from PIL import Image
|
||||
|
||||
|
|
@ -18,6 +17,22 @@ from tests.metagpt.provider.mock_llm_config import (
|
|||
mock_llm_config,
|
||||
mock_llm_config_proxy,
|
||||
)
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_openai_chat_completion,
|
||||
get_openai_chat_completion_chunk,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "AI assistant"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
|
||||
default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
|
||||
|
||||
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -106,9 +121,11 @@ class TestOpenAI:
|
|||
|
||||
def test_aask_code_json_decode_error(self, json_decode_error):
|
||||
instance = OpenAILLM(mock_llm_config)
|
||||
with pytest.raises(json.decoder.JSONDecodeError) as e:
|
||||
instance.get_choice_function_arguments(json_decode_error)
|
||||
assert "JSONDecodeError" in str(e)
|
||||
code = instance.get_choice_function_arguments(json_decode_error)
|
||||
assert "code" in code
|
||||
assert "language" in code
|
||||
assert "hello world" in code["code"]
|
||||
logger.info(f'code is : {code["code"]}')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -121,3 +138,29 @@ async def test_gen_image():
|
|||
|
||||
images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json")
|
||||
assert images[0].size == (1024, 1024)
|
||||
|
||||
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield default_resp_chunk
|
||||
|
||||
return Iterator()
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
llm = OpenAILLM(mock_llm_config)
|
||||
|
||||
resp = await llm.acompletion(messages)
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
assert resp.choices[0].message.content == resp_cont
|
||||
assert resp.usage == usage
|
||||
|
||||
await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
56
tests/metagpt/provider/test_qianfan_api.py
Normal file
56
tests/metagpt/provider/test_qianfan_api.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of qianfan api
|
||||
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
import pytest
|
||||
from qianfan.resources.typing import JsonBody, QfResponse
|
||||
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_qianfan_response,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "ERNIE-Bot-turbo"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
|
||||
return get_qianfan_response(name=name)
|
||||
|
||||
|
||||
async def mock_qianfan_ado(
|
||||
self, messages: list[dict], model: str, stream: bool = True, system: str = None
|
||||
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
|
||||
resps = [get_qianfan_response(name=name)]
|
||||
if stream:
|
||||
|
||||
async def aresp_iterator(resps: list[JsonBody]):
|
||||
for resp in resps:
|
||||
yield resp
|
||||
|
||||
return aresp_iterator(resps)
|
||||
else:
|
||||
return resps[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qianfan_acompletion(mocker):
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
|
||||
|
||||
qianfan_llm = QianFanLLM(mock_llm_config_qianfan)
|
||||
|
||||
resp = qianfan_llm.completion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
resp = await qianfan_llm.acompletion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)
|
||||
|
|
@ -4,12 +4,18 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_spark,
|
||||
)
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
resp_cont = resp_cont_tmpl.format(name="Spark")
|
||||
|
||||
|
||||
class MockWebSocketApp(object):
|
||||
|
|
@ -23,7 +29,7 @@ class MockWebSocketApp(object):
|
|||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config)
|
||||
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
|
|
@ -31,34 +37,26 @@ def test_get_msg_from_web(mocker):
|
|||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_content
|
||||
return resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_aask():
|
||||
llm = SparkLLM(Config.from_home("spark.yaml").llm)
|
||||
async def test_spark_aask(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await llm.aask("Hello!")
|
||||
print(resp)
|
||||
assert resp == resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
spark_gpt = SparkLLM(mock_llm_config)
|
||||
spark_llm = SparkLLM(mock_llm_config)
|
||||
|
||||
resp = await spark_gpt.acompletion([])
|
||||
assert resp == resp_content
|
||||
resp = await spark_llm.acompletion([])
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
|
||||
|
|
|
|||
|
|
@ -6,22 +6,24 @@ import pytest
|
|||
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_part_chat_completion,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {
|
||||
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
name = "ChatGLM-4"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_part_chat_completion(name)
|
||||
|
||||
|
||||
async def mock_zhipuai_acreate_stream(**kwargs):
|
||||
async def mock_zhipuai_acreate_stream(self, **kwargs):
|
||||
class MockResponse(object):
|
||||
async def _aread(self):
|
||||
class Iterator(object):
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
|
|
@ -37,7 +39,7 @@ async def mock_zhipuai_acreate_stream(**kwargs):
|
|||
return MockResponse()
|
||||
|
||||
|
||||
async def mock_zhipuai_acreate(**kwargs) -> dict:
|
||||
async def mock_zhipuai_acreate(self, **kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
|
|
@ -46,22 +48,12 @@ async def test_zhipuai_acompletion(mocker):
|
|||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
|
||||
|
||||
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["choices"][0]["message"]["content"] == resp_content
|
||||
resp = await zhipu_llm.acompletion(messages)
|
||||
assert resp["choices"][0]["message"]["content"] == resp_cont
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
|
||||
|
||||
|
||||
def test_zhipuai_proxy():
|
||||
|
|
|
|||
166
tests/metagpt/rag/engines/test_simple.py
Normal file
166
tests/metagpt/rag/engines/test_simple.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import Document, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
|
||||
|
||||
class TestSimpleEngine:
|
||||
@pytest.fixture
|
||||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_retriever(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_rankers(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_rankers")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_response_synthesizer(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
|
||||
|
||||
def test_from_docs(
|
||||
self,
|
||||
mocker,
|
||||
mock_simple_directory_reader,
|
||||
mock_vector_store_index,
|
||||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
):
|
||||
# Mock
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
Document(text="document1"),
|
||||
Document(text="document2"),
|
||||
]
|
||||
mock_get_retriever.return_value = mocker.MagicMock()
|
||||
mock_get_rankers.return_value = [mocker.MagicMock()]
|
||||
mock_get_response_synthesizer.return_value = mocker.MagicMock()
|
||||
|
||||
# Setup
|
||||
input_dir = "test_dir"
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
transformations = [mocker.MagicMock()]
|
||||
embed_model = mocker.MagicMock()
|
||||
llm = mocker.MagicMock()
|
||||
retriever_configs = [mocker.MagicMock()]
|
||||
ranker_configs = [mocker.MagicMock()]
|
||||
|
||||
# Execute
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_dir=input_dir,
|
||||
input_files=input_files,
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
configs=retriever_configs, index=mock_vector_store_index.return_value
|
||||
)
|
||||
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
# Mock
|
||||
test_query = "test query"
|
||||
expected_result = "expected result"
|
||||
mock_aquery = mocker.AsyncMock(return_value=expected_result)
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
engine.aquery = mock_aquery
|
||||
|
||||
# Execute
|
||||
result = await engine.asearch(test_query)
|
||||
|
||||
# Assertions
|
||||
mock_aquery.assert_called_once_with(test_query)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aretrieve(self, mocker):
|
||||
# Mock
|
||||
mock_query_bundle = mocker.patch("metagpt.rag.engines.simple.QueryBundle", return_value="query_bundle")
|
||||
mock_super_aretrieve = mocker.patch(
|
||||
"metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock
|
||||
)
|
||||
mock_super_aretrieve.return_value = [TextNode(text="node_with_score", metadata={"is_obj": False})]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
test_query = "test query"
|
||||
|
||||
# Execute
|
||||
result = await engine.aretrieve(test_query)
|
||||
|
||||
# Assertions
|
||||
mock_query_bundle.assert_called_once_with(test_query)
|
||||
mock_super_aretrieve.assert_called_once_with("query_bundle")
|
||||
assert result[0].text == "node_with_score"
|
||||
|
||||
def test_add_docs(self, mocker):
|
||||
# Mock
|
||||
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
Document(text="document1"),
|
||||
Document(text="document2"),
|
||||
]
|
||||
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index._transformations = mocker.MagicMock()
|
||||
|
||||
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
|
||||
mock_run_transformations.return_value = ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Execute
|
||||
engine.add_docs(input_files=input_files)
|
||||
|
||||
# Assertions
|
||||
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
|
||||
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
|
||||
|
||||
def test_add_objs(self, mocker):
|
||||
# Mock
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
# Setup
|
||||
class CustomTextNode(TextNode):
|
||||
def rag_key(self):
|
||||
return ""
|
||||
|
||||
def model_dump_json(self):
|
||||
return ""
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
|
||||
# Execute
|
||||
engine.add_objs(objs=objs)
|
||||
|
||||
# Assertions
|
||||
assert mock_retriever.add_nodes.call_count == 1
|
||||
for node in mock_retriever.add_nodes.call_args[0][0]:
|
||||
assert isinstance(node, TextNode)
|
||||
assert "is_obj" in node.metadata
|
||||
102
tests/metagpt/rag/factories/test_base.py
Normal file
102
tests/metagpt/rag/factories/test_base.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory, GenericFactory
|
||||
|
||||
|
||||
class TestGenericFactory:
|
||||
@pytest.fixture
|
||||
def creators(self):
|
||||
return {
|
||||
"type1": lambda name: f"Instance of type1 with {name}",
|
||||
"type2": lambda name: f"Instance of type2 with {name}",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self, creators):
|
||||
return GenericFactory(creators=creators)
|
||||
|
||||
def test_get_instance_success(self, factory):
|
||||
# Test successful retrieval of an instance
|
||||
key = "type1"
|
||||
instance = factory.get_instance(key, name="TestName")
|
||||
assert instance == "Instance of type1 with TestName"
|
||||
|
||||
def test_get_instance_failure(self, factory):
|
||||
# Test failure to retrieve an instance due to unregistered key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
factory.get_instance("unknown_key")
|
||||
assert "Creator not registered for key: unknown_key" in str(exc_info.value)
|
||||
|
||||
def test_get_instances_success(self, factory):
|
||||
# Test successful retrieval of multiple instances
|
||||
keys = ["type1", "type2"]
|
||||
instances = factory.get_instances(keys, name="TestName")
|
||||
expected = ["Instance of type1 with TestName", "Instance of type2 with TestName"]
|
||||
assert instances == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"keys,expected_exception_message",
|
||||
[
|
||||
(["unknown_key"], "Creator not registered for key: unknown_key"),
|
||||
(["type1", "unknown_key"], "Creator not registered for key: unknown_key"),
|
||||
],
|
||||
)
|
||||
def test_get_instances_with_failure(self, factory, keys, expected_exception_message):
|
||||
# Test failure to retrieve instances due to at least one unregistered key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
factory.get_instances(keys, name="TestName")
|
||||
assert expected_exception_message in str(exc_info.value)
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
"""A dummy config class for testing."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class TestConfigBasedFactory:
|
||||
@pytest.fixture
|
||||
def config_creators(self):
|
||||
return {
|
||||
DummyConfig: lambda config, **kwargs: f"Processed {config.name} with {kwargs.get('extra', 'no extra')}",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def config_factory(self, config_creators):
|
||||
return ConfigBasedFactory(creators=config_creators)
|
||||
|
||||
def test_get_instance_success(self, config_factory):
|
||||
# Test successful retrieval of an instance
|
||||
config = DummyConfig(name="TestConfig")
|
||||
instance = config_factory.get_instance(config, extra="additional data")
|
||||
assert instance == "Processed TestConfig with additional data"
|
||||
|
||||
def test_get_instance_failure(self, config_factory):
|
||||
# Test failure to retrieve an instance due to unknown config type
|
||||
class UnknownConfig:
|
||||
pass
|
||||
|
||||
config = UnknownConfig()
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
config_factory.get_instance(config)
|
||||
assert "Unknown config:" in str(exc_info.value)
|
||||
|
||||
def test_val_from_config_or_kwargs_priority(self):
|
||||
# Test that the value from the config object has priority over kwargs
|
||||
config = DummyConfig(name="ConfigName")
|
||||
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
|
||||
assert result == "ConfigName"
|
||||
|
||||
def test_val_from_config_or_kwargs_fallback_to_kwargs(self):
|
||||
# Test fallback to kwargs when config object does not have the value
|
||||
config = DummyConfig(name=None)
|
||||
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
|
||||
assert result == "KwargsName"
|
||||
|
||||
def test_val_from_config_or_kwargs_key_error(self):
|
||||
# Test KeyError when the key is not found in both config object and kwargs
|
||||
config = DummyConfig(name=None)
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
|
||||
41
tests/metagpt/rag/factories/test_ranker.py
Normal file
41
tests/metagpt/rag/factories/test_ranker.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import pytest
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.factories.ranker import RankerFactory
|
||||
from metagpt.rag.schema import LLMRankerConfig
|
||||
|
||||
|
||||
class TestRankerFactory:
|
||||
@pytest.fixture
|
||||
def ranker_factory(self) -> RankerFactory:
|
||||
return RankerFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm(self, mocker):
|
||||
return mocker.MagicMock(spec=LLM)
|
||||
|
||||
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
|
||||
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
|
||||
default_rankers = ranker_factory.get_rankers()
|
||||
assert len(default_rankers) == 0
|
||||
|
||||
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
rankers = ranker_factory.get_rankers(configs=[mock_config])
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
|
||||
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
ranker = ranker_factory._create_llm_ranker(mock_config)
|
||||
assert isinstance(ranker, LLMRerank)
|
||||
|
||||
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
extracted_llm = ranker_factory._extract_llm(config=mock_config)
|
||||
assert extracted_llm == mock_llm
|
||||
|
||||
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
|
||||
assert extracted_llm == mock_llm
|
||||
79
tests/metagpt/rag/factories/test_retriever.py
Normal file
79
tests/metagpt/rag/factories/test_retriever.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
from metagpt.rag.factories.retriever import RetrieverFactory
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
|
||||
|
||||
|
||||
class TestRetrieverFactory:
|
||||
@pytest.fixture
|
||||
def retriever_factory(self):
|
||||
return RetrieverFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_index(self, mocker):
|
||||
return mocker.MagicMock(spec=faiss.IndexFlatL2)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
mock = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock._embed_model = mocker.MagicMock()
|
||||
mock.docstore.docs.values.return_value = []
|
||||
return mock
|
||||
|
||||
def test_get_retriever_with_faiss_config(
|
||||
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
|
||||
):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(
|
||||
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
|
||||
):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mock_vector_store_index.as_retriever = mocker.MagicMock()
|
||||
|
||||
retriever = retriever_factory.get_retriever()
|
||||
|
||||
mock_vector_store_index.as_retriever.assert_called_once()
|
||||
assert retriever is mock_vector_store_index.as_retriever.return_value
|
||||
|
||||
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
|
||||
|
||||
extracted_index = retriever_factory._extract_index(config=mock_config)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
37
tests/metagpt/rag/retrievers/test_bm25_retriever.py
Normal file
37
tests/metagpt/rag/retrievers/test_bm25_retriever.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
|
||||
|
||||
class TestDynamicBM25Retriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc1.get_content.return_value = "Document content 1"
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.doc2.get_content.return_value = "Document content 2"
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟index
|
||||
index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
|
||||
# 模拟nodes和tokenizer参数
|
||||
mock_nodes = []
|
||||
mock_tokenizer = mocker.MagicMock()
|
||||
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
|
||||
# 初始化DynamicBM25Retriever对象,并提供必需的参数
|
||||
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
|
||||
|
||||
def test_add_docs_updates_nodes_and_corpus(self):
|
||||
# Execute
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
# Assertions
|
||||
assert len(self.retriever._nodes) == len(self.mock_nodes)
|
||||
assert len(self.retriever._corpus) == len(self.mock_nodes)
|
||||
self.retriever._tokenizer.assert_called()
|
||||
self.mock_bm25okapi.assert_called()
|
||||
22
tests/metagpt/rag/retrievers/test_faiss_retriever.py
Normal file
22
tests/metagpt/rag/retrievers/test_faiss_retriever.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
|
||||
|
||||
class TestFAISSRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟FAISSRetriever的_index属性
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = FAISSRetriever(self.mock_index)
|
||||
|
||||
def test_add_docs_calls_insert_for_each_document(self, mocker):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
assert self.mock_index.insert_nodes.assert_called
|
||||
39
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
39
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
|
||||
|
||||
class TestSimpleHybridRetriever:
|
||||
@pytest.mark.asyncio
|
||||
async def test_aretrieve(self):
|
||||
question = "test query"
|
||||
|
||||
# Create mock retrievers
|
||||
mock_retriever1 = AsyncMock()
|
||||
mock_retriever1.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="1"), score=1.0),
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
]
|
||||
|
||||
mock_retriever2 = AsyncMock()
|
||||
mock_retriever2.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
NodeWithScore(node=TextNode(id_="3"), score=0.8),
|
||||
]
|
||||
|
||||
# Instantiate the SimpleHybridRetriever with the mock retrievers
|
||||
hybrid_retriever = SimpleHybridRetriever(mock_retriever1, mock_retriever2)
|
||||
|
||||
# Call the _aretrieve method
|
||||
results = await hybrid_retriever._aretrieve(question)
|
||||
|
||||
# Check if the results are as expected
|
||||
assert len(results) == 3 # Should be 3 unique nodes
|
||||
assert set(node.node.node_id for node in results) == {"1", "2", "3"}
|
||||
|
||||
# Check if the scores are correct (assuming you want the highest score)
|
||||
node_scores = {node.node.node_id: node.score for node in results}
|
||||
assert node_scores["2"] == 0.95
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ci.code_interpreter import CodeInterpreter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("auto_run", [(True), (False)])
|
||||
async def test_code_interpreter(mocker, auto_run):
|
||||
mocker.patch("metagpt.actions.ci.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
|
||||
mocker.patch("builtins.input", return_value="confirm")
|
||||
|
||||
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
tools = []
|
||||
|
||||
ci = CodeInterpreter(auto_run=auto_run, use_tools=True, tools=tools)
|
||||
rsp = await ci.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ci.ml_engineer import MLEngineer
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from tests.metagpt.actions.ci.test_debug_code import CODE, DebugContext, ErrorStr
|
||||
|
||||
|
||||
def test_mle_init():
|
||||
ci = MLEngineer(goal="test", auto_run=True, use_tools=True, tools=["tool1", "tool2"])
|
||||
assert ci.tools == []
|
||||
|
||||
|
||||
MockPlan = Plan(
|
||||
goal="This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: 'tests/data/ml_datasets/titanic/split_train.csv', eval data path: 'tests/data/ml_datasets/titanic/split_eval.csv'.",
|
||||
context="",
|
||||
tasks=[
|
||||
Task(
|
||||
task_id="1",
|
||||
dependent_task_ids=[],
|
||||
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
|
||||
task_type="eda",
|
||||
code="",
|
||||
result="",
|
||||
is_success=False,
|
||||
is_finished=False,
|
||||
)
|
||||
],
|
||||
task_map={
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
dependent_task_ids=[],
|
||||
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
|
||||
task_type="eda",
|
||||
code="",
|
||||
result="",
|
||||
is_success=False,
|
||||
is_finished=False,
|
||||
)
|
||||
},
|
||||
current_task_id="1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_write_code(mocker):
|
||||
data_path = "tests/data/ml_datasets/titanic"
|
||||
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.planner.plan = MockPlan
|
||||
|
||||
code, _ = await mle._write_code()
|
||||
assert data_path in code["code"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_update_data_columns(mocker):
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.planner.plan = MockPlan
|
||||
|
||||
# manually update task type to test update
|
||||
mle.planner.plan.current_task.task_type = ToolType.DATA_PREPROCESS.value
|
||||
|
||||
result = await mle._update_data_columns()
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_debug_code(mocker):
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.working_memory.add(Message(content=ErrorStr, cause_by=ExecuteNbCode))
|
||||
mle.latest_code = CODE
|
||||
mle.debug_context = DebugContext
|
||||
code, _ = await mle._write_code()
|
||||
assert len(code) > 0
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_ml_engineer():
|
||||
data_path = "tests/data/ml_datasets/titanic"
|
||||
requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
|
||||
tools = ["FillMissingValue", "CatCross", "dummy_tool"]
|
||||
|
||||
mle = MLEngineer(auto_run=True, use_tools=True, tools=tools)
|
||||
rsp = await mle.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
34
tests/metagpt/roles/di/test_data_interpreter.py
Normal file
34
tests/metagpt/roles/di/test_data_interpreter.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("auto_run", [(True), (False)])
|
||||
async def test_interpreter(mocker, auto_run):
|
||||
mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
|
||||
mocker.patch("builtins.input", return_value="confirm")
|
||||
|
||||
requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
|
||||
|
||||
di = DataInterpreter(auto_run=auto_run)
|
||||
rsp = await di.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
||||
finished_tasks = di.planner.plan.get_finished_tasks()
|
||||
assert len(finished_tasks) > 0
|
||||
assert len(finished_tasks[0].code) > 0 # check one task to see if code is recorded
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interpreter_react_mode(mocker):
|
||||
mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
|
||||
|
||||
requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
|
||||
|
||||
di = DataInterpreter(react_mode="react")
|
||||
rsp = await di.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
@ -30,6 +30,17 @@ async def test_run(mocker, context):
|
|||
language: str
|
||||
agent_description: str
|
||||
cause_by: str
|
||||
agent_skills: list
|
||||
|
||||
agent_skills = [
|
||||
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
|
||||
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
|
||||
]
|
||||
|
||||
inputs = [
|
||||
{
|
||||
|
|
@ -48,6 +59,7 @@ async def test_run(mocker, context):
|
|||
"language": "English",
|
||||
"agent_description": "chatterbox",
|
||||
"cause_by": any_to_str(TalkAction),
|
||||
"agent_skills": [],
|
||||
},
|
||||
{
|
||||
"memory": {
|
||||
|
|
@ -65,24 +77,16 @@ async def test_run(mocker, context):
|
|||
"language": "English",
|
||||
"agent_description": "painter",
|
||||
"cause_by": any_to_str(SkillAction),
|
||||
"agent_skills": agent_skills,
|
||||
},
|
||||
]
|
||||
agent_skills = [
|
||||
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
|
||||
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
|
||||
]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
role = Assistant(language="Chinese", context=context)
|
||||
role.context.kwargs.language = seed.language
|
||||
role.context.kwargs.agent_description = seed.agent_description
|
||||
role.context.kwargs.agent_skills = agent_skills
|
||||
role.context.kwargs.agent_skills = seed.agent_skills
|
||||
|
||||
role.memory = seed.memory # Restore historical conversation content.
|
||||
while True:
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
@File : test_tutorial_assistant.py
|
||||
"""
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.const import TUTORIAL_PATH
|
||||
from metagpt.roles.tutorial_assistant import TutorialAssistant
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -20,9 +20,8 @@ async def test_tutorial_assistant(language: str, topic: str, context):
|
|||
msg = await role.run(topic)
|
||||
assert TUTORIAL_PATH.exists()
|
||||
filename = msg.content
|
||||
async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
|
||||
content = await reader.read()
|
||||
assert "pip" in content
|
||||
content = await aread(filename=filename)
|
||||
assert "pip" in content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
37
tests/metagpt/strategy/test_planner.py
Normal file
37
tests/metagpt/strategy/test_planner.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from metagpt.schema import Plan, Task
|
||||
from metagpt.strategy.planner import Planner
|
||||
from metagpt.strategy.task_type import TaskType
|
||||
|
||||
MOCK_TASK_MAP = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="test instruction for finished task",
|
||||
task_type=TaskType.EDA.type_name,
|
||||
dependent_task_ids=[],
|
||||
code="some finished test code",
|
||||
result="some finished test result",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="test instruction for current task",
|
||||
task_type=TaskType.DATA_PREPROCESS.type_name,
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
MOCK_PLAN = Plan(
|
||||
goal="test goal",
|
||||
tasks=list(MOCK_TASK_MAP.values()),
|
||||
task_map=MOCK_TASK_MAP,
|
||||
current_task_id="2",
|
||||
)
|
||||
|
||||
|
||||
def test_planner_get_plan_status():
|
||||
planner = Planner(plan=MOCK_PLAN)
|
||||
status = planner.get_plan_status()
|
||||
|
||||
assert "some finished test code" in status
|
||||
assert "some finished test result" in status
|
||||
assert "test instruction for current task" in status
|
||||
assert TaskType.DATA_PREPROCESS.value.guidance in status # current task guidance
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import RepoParser
|
||||
from metagpt.repo_parser import DotClassAttribute, DotClassMethod, DotReturn, RepoParser
|
||||
|
||||
|
||||
def test_repo_parser():
|
||||
|
|
@ -23,3 +25,140 @@ def test_error():
|
|||
"""_parse_file should return empty list when file not existed"""
|
||||
rsp = RepoParser._parse_file(Path("test_not_existed_file.py"))
|
||||
assert rsp == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("v", "name", "type_", "default_", "compositions"),
|
||||
[
|
||||
("children : dict[str, 'ActionNode']", "children", "dict[str,ActionNode]", "", ["ActionNode"]),
|
||||
("context : str", "context", "str", "", []),
|
||||
("example", "example", "", "", []),
|
||||
("expected_type : Type", "expected_type", "Type", "", ["Type"]),
|
||||
("args : Optional[Dict]", "args", "Optional[Dict]", "", []),
|
||||
("rsp : Optional[Message] = Message.Default", "rsp", "Optional[Message]", "Message.Default", ["Message"]),
|
||||
(
|
||||
"browser : Literal['chrome', 'firefox', 'edge', 'ie']",
|
||||
"browser",
|
||||
"Literal['chrome','firefox','edge','ie']",
|
||||
"",
|
||||
[],
|
||||
),
|
||||
(
|
||||
"browser : Dict[ Message, Literal['chrome', 'firefox', 'edge', 'ie'] ]",
|
||||
"browser",
|
||||
"Dict[Message,Literal['chrome','firefox','edge','ie']]",
|
||||
"",
|
||||
["Message"],
|
||||
),
|
||||
("attributes : List[ClassAttribute]", "attributes", "List[ClassAttribute]", "", ["ClassAttribute"]),
|
||||
("attributes = []", "attributes", "", "[]", []),
|
||||
(
|
||||
"request_timeout: Optional[Union[float, Tuple[float, float]]]",
|
||||
"request_timeout",
|
||||
"Optional[Union[float,Tuple[float,float]]]",
|
||||
"",
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_member(v, name, type_, default_, compositions):
|
||||
attr = DotClassAttribute.parse(v)
|
||||
assert name == attr.name
|
||||
assert type_ == attr.type_
|
||||
assert default_ == attr.default_
|
||||
assert compositions == attr.compositions
|
||||
assert v == attr.description
|
||||
|
||||
json_data = attr.model_dump_json()
|
||||
v = DotClassAttribute.model_validate_json(json_data)
|
||||
assert v == attr
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("line", "package_name", "info"),
|
||||
[
|
||||
(
|
||||
'"metagpt.roles.architect.Architect" [color="black", fontcolor="black", label=<{Architect|constraints : str<br ALIGN="LEFT"/>goal : str<br ALIGN="LEFT"/>name : str<br ALIGN="LEFT"/>profile : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
|
||||
"metagpt.roles.architect.Architect",
|
||||
"Architect|constraints : str\ngoal : str\nname : str\nprofile : str\n|",
|
||||
),
|
||||
(
|
||||
'"metagpt.actions.skill_action.ArgumentsParingAction" [color="black", fontcolor="black", label=<{ArgumentsParingAction|args : Optional[Dict]<br ALIGN="LEFT"/>ask : str<br ALIGN="LEFT"/>prompt<br ALIGN="LEFT"/>rsp : Optional[Message]<br ALIGN="LEFT"/>skill<br ALIGN="LEFT"/>|parse_arguments(skill_name, txt): dict<br ALIGN="LEFT"/>run(with_message): Message<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
|
||||
"metagpt.actions.skill_action.ArgumentsParingAction",
|
||||
"ArgumentsParingAction|args : Optional[Dict]\nask : str\nprompt\nrsp : Optional[Message]\nskill\n|parse_arguments(skill_name, txt): dict\nrun(with_message): Message\n",
|
||||
),
|
||||
(
|
||||
'"metagpt.strategy.base.BaseEvaluator" [color="black", fontcolor="black", label=<{BaseEvaluator|<br ALIGN="LEFT"/>|<I>status_verify</I>()<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
|
||||
"metagpt.strategy.base.BaseEvaluator",
|
||||
"BaseEvaluator|\n|<I>status_verify</I>()\n",
|
||||
),
|
||||
(
|
||||
'"metagpt.configs.browser_config.BrowserConfig" [color="black", fontcolor="black", label=<{BrowserConfig|browser : Literal[\'chrome\', \'firefox\', \'edge\', \'ie\']<br ALIGN="LEFT"/>driver : Literal[\'chromium\', \'firefox\', \'webkit\']<br ALIGN="LEFT"/>engine<br ALIGN="LEFT"/>path : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
|
||||
"metagpt.configs.browser_config.BrowserConfig",
|
||||
"BrowserConfig|browser : Literal['chrome', 'firefox', 'edge', 'ie']\ndriver : Literal['chromium', 'firefox', 'webkit']\nengine\npath : str\n|",
|
||||
),
|
||||
(
|
||||
'"metagpt.tools.search_engine_serpapi.SerpAPIWrapper" [color="black", fontcolor="black", label=<{SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]<br ALIGN="LEFT"/>model_config<br ALIGN="LEFT"/>params : dict<br ALIGN="LEFT"/>search_engine : Optional[Any]<br ALIGN="LEFT"/>serpapi_api_key : Optional[str]<br ALIGN="LEFT"/>|check_serpapi_api_key(val: str)<br ALIGN="LEFT"/>get_params(query: str): Dict[str, str]<br ALIGN="LEFT"/>results(query: str, max_results: int): dict<br ALIGN="LEFT"/>run(query, max_results: int, as_string: bool): str<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
|
||||
"metagpt.tools.search_engine_serpapi.SerpAPIWrapper",
|
||||
"SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]\nmodel_config\nparams : dict\nsearch_engine : Optional[Any]\nserpapi_api_key : Optional[str]\n|check_serpapi_api_key(val: str)\nget_params(query: str): Dict[str, str]\nresults(query: str, max_results: int): dict\nrun(query, max_results: int, as_string: bool): str\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_split_class_line(line, package_name, info):
|
||||
p, i = RepoParser._split_class_line(line)
|
||||
assert p == package_name
|
||||
assert i == info
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("v", "name", "args", "return_args"),
|
||||
[
|
||||
(
|
||||
"<I>arequest</I>(method, url, params, headers, files, stream: Literal[True], request_id: Optional[str], request_timeout: Optional[Union[float, Tuple[float, float]]]): Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
|
||||
"arequest",
|
||||
[
|
||||
DotClassAttribute(name="method", description="method"),
|
||||
DotClassAttribute(name="url", description="url"),
|
||||
DotClassAttribute(name="params", description="params"),
|
||||
DotClassAttribute(name="headers", description="headers"),
|
||||
DotClassAttribute(name="files", description="files"),
|
||||
DotClassAttribute(name="stream", type_="Literal[True]", description="stream: Literal[True]"),
|
||||
DotClassAttribute(name="request_id", type_="Optional[str]", description="request_id: Optional[str]"),
|
||||
DotClassAttribute(
|
||||
name="request_timeout",
|
||||
type_="Optional[Union[float,Tuple[float,float]]]",
|
||||
description="request_timeout: Optional[Union[float, Tuple[float, float]]]",
|
||||
),
|
||||
],
|
||||
DotReturn(
|
||||
type_="Tuple[AsyncGenerator[OpenAIResponse,None],bool,str]",
|
||||
compositions=["AsyncGenerator", "OpenAIResponse"],
|
||||
description="Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
|
||||
),
|
||||
),
|
||||
(
|
||||
"<I>update</I>(subject: str, predicate: str, object_: str)",
|
||||
"update",
|
||||
[
|
||||
DotClassAttribute(name="subject", type_="str", description="subject: str"),
|
||||
DotClassAttribute(name="predicate", type_="str", description="predicate: str"),
|
||||
DotClassAttribute(name="object_", type_="str", description="object_: str"),
|
||||
],
|
||||
DotReturn(description=""),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_method(v, name, args, return_args):
|
||||
method = DotClassMethod.parse(v)
|
||||
assert method.name == name
|
||||
assert method.args == args
|
||||
assert method.return_args == return_args
|
||||
assert method.description == v
|
||||
|
||||
json_data = method.model_dump_json()
|
||||
v = DotClassMethod.model_validate_json(json_data)
|
||||
assert v == method
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -18,9 +18,6 @@ from metagpt.actions.write_code import WriteCode
|
|||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.schema import (
|
||||
AIMessage,
|
||||
ClassAttribute,
|
||||
ClassMethod,
|
||||
ClassView,
|
||||
CodeSummarizeContext,
|
||||
Document,
|
||||
Message,
|
||||
|
|
@ -28,6 +25,9 @@ from metagpt.schema import (
|
|||
Plan,
|
||||
SystemMessage,
|
||||
Task,
|
||||
UMLClassAttribute,
|
||||
UMLClassMethod,
|
||||
UMLClassView,
|
||||
UserMessage,
|
||||
)
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
|
@ -159,27 +159,26 @@ def test_CodeSummarizeContext(file_list, want):
|
|||
|
||||
|
||||
def test_class_view():
|
||||
attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True)
|
||||
assert attr_a.get_mermaid(align=1) == "\t+int a=0*"
|
||||
attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True)
|
||||
assert attr_b.get_mermaid(align=0) == '#str b="0"$'
|
||||
class_view = ClassView(name="A")
|
||||
attr_a = UMLClassAttribute(name="a", value_type="int", default_value="0", visibility="+")
|
||||
assert attr_a.get_mermaid(align=1) == "\t+int a=0"
|
||||
attr_b = UMLClassAttribute(name="b", value_type="str", default_value="0", visibility="#")
|
||||
assert attr_b.get_mermaid(align=0) == '#str b="0"'
|
||||
class_view = UMLClassView(name="A")
|
||||
class_view.attributes = [attr_a, attr_b]
|
||||
|
||||
method_a = ClassMethod(name="run", visibility="+", abstraction=True)
|
||||
assert method_a.get_mermaid(align=1) == "\t+run()*"
|
||||
method_b = ClassMethod(
|
||||
method_a = UMLClassMethod(name="run", visibility="+")
|
||||
assert method_a.get_mermaid(align=1) == "\t+run()"
|
||||
method_b = UMLClassMethod(
|
||||
name="_test",
|
||||
visibility="#",
|
||||
static=True,
|
||||
args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")],
|
||||
args=[UMLClassAttribute(name="a", value_type="str"), UMLClassAttribute(name="b", value_type="int")],
|
||||
return_type="str",
|
||||
)
|
||||
assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$"
|
||||
assert method_b.get_mermaid(align=0) == "#_test(str a,int b) str"
|
||||
class_view.methods = [method_a, method_b]
|
||||
assert (
|
||||
class_view.get_mermaid(align=0)
|
||||
== 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n'
|
||||
== 'class A{\n\t+int a=0\n\t#str b="0"\n\t+run()\n\t#_test(str a,int b) str\n}\n'
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
7
tests/metagpt/tools/libs/test_email_login.py
Normal file
7
tests/metagpt/tools/libs/test_email_login.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from metagpt.tools.libs.email_login import email_login_imap
|
||||
|
||||
|
||||
def test_email_login(mocker):
|
||||
mock_mailbox = mocker.patch("metagpt.tools.libs.email_login.MailBox.login")
|
||||
mock_mailbox.login.return_value = mocker.Mock()
|
||||
email_login_imap("test@outlook.com", "test_password")
|
||||
|
|
@ -60,18 +60,24 @@ async def test_generate_webpages(mock_webpage_filename_with_styles_and_scripts,
|
|||
async def test_save_webpages_with_styles_and_scripts(mock_webpage_filename_with_styles_and_scripts, image_path):
|
||||
generator = GPTvGenerator()
|
||||
webpages = await generator.generate_webpages(image_path)
|
||||
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
|
||||
webpages_dir = generator.save_webpages(webpages=webpages, save_folder_name="test_1")
|
||||
logs.logger.info(webpages_dir)
|
||||
assert webpages_dir.exists()
|
||||
assert (webpages_dir / "index.html").exists()
|
||||
assert (webpages_dir / "styles.css").exists()
|
||||
assert (webpages_dir / "scripts.js").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_webpages_with_style_and_script(mock_webpage_filename_with_style_and_script, image_path):
|
||||
generator = GPTvGenerator()
|
||||
webpages = await generator.generate_webpages(image_path)
|
||||
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
|
||||
webpages_dir = generator.save_webpages(webpages=webpages, save_folder_name="test_2")
|
||||
logs.logger.info(webpages_dir)
|
||||
assert webpages_dir.exists()
|
||||
assert (webpages_dir / "index.html").exists()
|
||||
assert (webpages_dir / "style.css").exists()
|
||||
assert (webpages_dir / "script.js").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from metagpt.tools.libs.web_scraping import scrape_web_playwright
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scrape_web_playwright():
|
||||
test_url = "https://www.deepwisdom.ai"
|
||||
async def test_scrape_web_playwright(http_server):
|
||||
server, test_url = await http_server()
|
||||
|
||||
result = await scrape_web_playwright(test_url)
|
||||
|
||||
|
|
@ -21,3 +21,4 @@ async def test_scrape_web_playwright():
|
|||
assert not result["inner_text"].endswith(" ")
|
||||
assert not result["html"].startswith(" ")
|
||||
assert not result["html"].endswith(" ")
|
||||
await server.stop()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from typing import Callable
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.search_config import SearchConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
|
@ -53,14 +52,11 @@ async def test_search_engine(
|
|||
search_engine_config = {"engine": search_engine_type, "run_func": run_func}
|
||||
|
||||
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-serpapi-key"
|
||||
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-google-key"
|
||||
search_engine_config["cse_id"] = "mock-google-cse"
|
||||
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-serper-key"
|
||||
|
||||
async def test(search_engine):
|
||||
|
|
|
|||
|
|
@ -1,44 +1,8 @@
|
|||
from typing import Literal, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
|
||||
|
||||
|
||||
def test_docstring_to_schema():
|
||||
docstring = """
|
||||
Some test desc.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
|
||||
used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
expected = {
|
||||
"description": "Some test desc.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
}
|
||||
schema = docstring_to_schema(docstring)
|
||||
assert schema == expected
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
|
||||
|
||||
class DummyClass:
|
||||
|
|
@ -81,12 +45,25 @@ class DummyClass:
|
|||
pass
|
||||
|
||||
|
||||
def dummy_fn(df: pd.DataFrame) -> dict:
|
||||
def dummy_fn(
|
||||
df: pd.DataFrame,
|
||||
s: str,
|
||||
k: int = 5,
|
||||
type: Literal["a", "b", "c"] = "a",
|
||||
test_dict: dict[str, int] = None,
|
||||
test_union: Union[str, list[str]] = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Analyzes a DataFrame and categorizes its columns based on data types.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The DataFrame to be analyzed.
|
||||
df: The DataFrame to be analyzed.
|
||||
Another line for df.
|
||||
s (str): Some test string param.
|
||||
Another line for s.
|
||||
k (int, optional): Some test integer param. Defaults to 5.
|
||||
type (Literal["a", "b", "c"], optional): Some test type. Defaults to 'a'.
|
||||
more_args: will be omitted here for testing
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
|
||||
|
|
@ -115,41 +92,21 @@ def test_convert_code_to_tool_schema_class():
|
|||
"methods": {
|
||||
"__init__": {
|
||||
"type": "function",
|
||||
"description": "Initialize self.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"description": "Initialize self. ",
|
||||
"signature": "(self, features: list, strategy: str = 'mean', fill_value=None)",
|
||||
"parameters": "Args: features (list): Columns to be processed. strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
},
|
||||
"fit": {
|
||||
"type": "function",
|
||||
"description": "Fit the FillMissingValue model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"description": "Fit the FillMissingValue model. ",
|
||||
"signature": "(self, df: pandas.core.frame.DataFrame)",
|
||||
"parameters": "Args: df (pd.DataFrame): The input DataFrame.",
|
||||
},
|
||||
"transform": {
|
||||
"type": "function",
|
||||
"description": "Transform the input DataFrame with the fitted model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
"description": "Transform the input DataFrame with the fitted model. ",
|
||||
"signature": "(self, df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame",
|
||||
"parameters": "Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -160,11 +117,9 @@ def test_convert_code_to_tool_schema_class():
|
|||
def test_convert_code_to_tool_schema_function():
|
||||
expected = {
|
||||
"type": "function",
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
|
||||
"signature": "(df: pandas.core.frame.DataFrame, s: str, k: int = 5, type: Literal['a', 'b', 'c'] = 'a', test_dict: dict[str, int] = None, test_union: Union[str, list[str]] = '') -> dict",
|
||||
"parameters": "Args: df: The DataFrame to be analyzed. Another line for df. s (str): Some test string param. Another line for s. k (int, optional): Some test integer param. Defaults to 5. type (Literal[\"a\", \"b\", \"c\"], optional): Some test type. Defaults to 'a'. more_args: will be omitted here for testing Returns: dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others'). Each key corresponds to a list of column names belonging to that category.",
|
||||
}
|
||||
schema = convert_code_to_tool_schema(dummy_fn)
|
||||
assert schema == expected
|
||||
|
|
|
|||
90
tests/metagpt/tools/test_tool_recommend.py
Normal file
90
tests/metagpt/tools/test_tool_recommend.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.schema import Plan, Task
|
||||
from metagpt.tools import TOOL_REGISTRY
|
||||
from metagpt.tools.tool_recommend import (
|
||||
BM25ToolRecommender,
|
||||
ToolRecommender,
|
||||
TypeMatchToolRecommender,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_plan(mocker):
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="conduct feature engineering, add new features on the dataset",
|
||||
task_type="feature engineering",
|
||||
)
|
||||
}
|
||||
plan = Plan(
|
||||
goal="test requirement",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="1",
|
||||
)
|
||||
return plan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bm25_tr(mocker):
|
||||
tr = BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
|
||||
return tr
|
||||
|
||||
|
||||
def test_tr_init():
|
||||
tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping", "non-existing tool"])
|
||||
# web_scraping is a tool tag, it has one tool scrape_web_playwright
|
||||
assert list(tr.tools.keys()) == [
|
||||
"FillMissingValue",
|
||||
"PolynomialExpansion",
|
||||
"scrape_web_playwright",
|
||||
]
|
||||
|
||||
|
||||
def test_tr_init_default_tools_value():
|
||||
tr = ToolRecommender()
|
||||
assert tr.tools == {}
|
||||
|
||||
|
||||
def test_tr_init_tools_all():
|
||||
tr = ToolRecommender(tools=["<all>"])
|
||||
assert list(tr.tools.keys()) == list(TOOL_REGISTRY.get_all_tools().keys())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_tr_recall_with_plan(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recall_tools(plan=mock_plan)
|
||||
assert len(result) == 3
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_tr_recall_no_plan(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recall_tools(
|
||||
context="conduct feature engineering, add new features on the dataset", plan=None
|
||||
)
|
||||
assert len(result) == 3
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_recommend_tools(mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recommend_tools(context="conduct feature engineering, add new features on the dataset")
|
||||
assert len(result) == 2 # web scraping tool should be filtered out at rank stage
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tm_tr_recall_with_plan(mock_plan, mock_bm25_tr):
|
||||
tr = TypeMatchToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
|
||||
result = await tr.recall_tools(plan=mock_plan)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.tool_registry import ToolRegistry
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -9,25 +8,11 @@ def tool_registry():
|
|||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry_full():
|
||||
return ToolRegistry(tool_types=ToolType)
|
||||
|
||||
|
||||
# Test Initialization
|
||||
def test_initialization(tool_registry):
|
||||
assert isinstance(tool_registry, ToolRegistry)
|
||||
assert tool_registry.tools == {}
|
||||
assert tool_registry.tool_types == {}
|
||||
assert tool_registry.tools_by_types == {}
|
||||
|
||||
|
||||
# Test Initialization with tool types
|
||||
def test_initialize_with_tool_types(tool_registry_full):
|
||||
assert isinstance(tool_registry_full, ToolRegistry)
|
||||
assert tool_registry_full.tools == {}
|
||||
assert tool_registry_full.tools_by_types == {}
|
||||
assert "data_preprocess" in tool_registry_full.tool_types
|
||||
assert tool_registry.tools_by_tags == {}
|
||||
|
||||
|
||||
class TestClassTool:
|
||||
|
|
@ -72,31 +57,24 @@ def test_get_tool(tool_registry):
|
|||
assert "description" in tool.schemas
|
||||
|
||||
|
||||
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
|
||||
def test_has_tool_type(tool_registry_full):
|
||||
assert tool_registry_full.has_tool_type("data_preprocess")
|
||||
assert not tool_registry_full.has_tool_type("NonexistentType")
|
||||
def test_has_tool_tag(tool_registry):
|
||||
tool_registry.register_tool(
|
||||
"TestClassTool", "/path/to/tool", tool_source_object=TestClassTool, tags=["machine learning", "test"]
|
||||
)
|
||||
assert tool_registry.has_tool_tag("test")
|
||||
assert not tool_registry.has_tool_tag("Non-existent tag")
|
||||
|
||||
|
||||
def test_get_tool_type(tool_registry_full):
|
||||
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
|
||||
assert retrieved_type is not None
|
||||
assert retrieved_type.name == "data_preprocess"
|
||||
|
||||
|
||||
def test_get_tools_by_type(tool_registry):
|
||||
tool_type_name = "TestType"
|
||||
def test_get_tools_by_tag(tool_registry):
|
||||
tool_tag_name = "Test Tag"
|
||||
tool_name = "TestTool"
|
||||
tool_path = "/path/to/tool"
|
||||
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
|
||||
tool_registry.register_tool(tool_name, tool_path, tags=[tool_tag_name], tool_source_object=TestClassTool)
|
||||
|
||||
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
|
||||
assert tools_by_type is not None
|
||||
assert tool_name in tools_by_type
|
||||
tools_by_tag = tool_registry.get_tools_by_tag(tool_tag_name)
|
||||
assert tools_by_tag is not None
|
||||
assert tool_name in tools_by_tag
|
||||
|
||||
|
||||
# Test case for when the tool type does not exist
|
||||
def test_get_tools_by_nonexistent_type(tool_registry):
|
||||
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
|
||||
assert not tools_by_type
|
||||
tools_by_tag_non_existent = tool_registry.get_tools_by_tag("Non-existent Tag")
|
||||
assert not tools_by_tag_non_existent
|
||||
|
|
|
|||
|
|
@ -9,14 +9,16 @@ from metagpt.utils.parse_html import WebPage
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"browser_type, url, urls",
|
||||
"browser_type",
|
||||
[
|
||||
(WebBrowserEngineType.PLAYWRIGHT, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
(WebBrowserEngineType.SELENIUM, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
WebBrowserEngineType.PLAYWRIGHT,
|
||||
WebBrowserEngineType.SELENIUM,
|
||||
],
|
||||
ids=["playwright", "selenium"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, url, urls):
|
||||
async def test_scrape_web_page(browser_type, http_server):
|
||||
server, url = await http_server()
|
||||
urls = [url, url, url]
|
||||
browser = web_browser_engine.WebBrowserEngine(engine=browser_type)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
|
|
@ -27,6 +29,7 @@ async def test_scrape_web_page(browser_type, url, urls):
|
|||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
await server.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -9,18 +9,28 @@ from metagpt.utils.parse_html import WebPage
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"browser_type, use_proxy, kwagrs, url, urls",
|
||||
"browser_type, use_proxy, kwagrs,",
|
||||
[
|
||||
("chromium", {"proxy": True}, {}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
|
||||
("firefox", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
|
||||
("webkit", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
|
||||
("chromium", {"proxy": True}, {}),
|
||||
(
|
||||
"firefox",
|
||||
{},
|
||||
{"ignore_https_errors": True},
|
||||
),
|
||||
(
|
||||
"webkit",
|
||||
{},
|
||||
{"ignore_https_errors": True},
|
||||
),
|
||||
],
|
||||
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd):
|
||||
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, proxy, capfd, http_server):
|
||||
server, url = await http_server()
|
||||
urls = [url, url, url]
|
||||
proxy_url = None
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
proxy_server, proxy_url = await proxy()
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, proxy=proxy_url, **kwagrs)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
|
|
@ -32,8 +42,10 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
|
|||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
proxy_server.close()
|
||||
await proxy_server.wait_closed()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
await server.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
import browsers
|
||||
import pytest
|
||||
|
||||
|
|
@ -10,51 +11,48 @@ from metagpt.utils.parse_html import WebPage
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"browser_type, use_proxy, url, urls",
|
||||
"browser_type, use_proxy,",
|
||||
[
|
||||
pytest.param(
|
||||
"chrome",
|
||||
True,
|
||||
"https://deepwisdom.ai",
|
||||
("https://deepwisdom.ai",),
|
||||
False,
|
||||
marks=pytest.mark.skipif(not browsers.get("chrome"), reason="chrome browser not found"),
|
||||
),
|
||||
pytest.param(
|
||||
"firefox",
|
||||
False,
|
||||
"https://deepwisdom.ai",
|
||||
("https://deepwisdom.ai",),
|
||||
marks=pytest.mark.skipif(not browsers.get("firefox"), reason="firefox browser not found"),
|
||||
),
|
||||
pytest.param(
|
||||
"edge",
|
||||
False,
|
||||
"https://deepwisdom.ai",
|
||||
("https://deepwisdom.ai",),
|
||||
marks=pytest.mark.skipif(not browsers.get("msedge"), reason="edge browser not found"),
|
||||
),
|
||||
],
|
||||
ids=["chrome-normal", "firefox-normal", "edge-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd):
|
||||
async def test_scrape_web_page(browser_type, use_proxy, proxy, capfd, http_server):
|
||||
# Prerequisites
|
||||
# firefox, chrome, Microsoft Edge
|
||||
server, url = await http_server()
|
||||
urls = [url, url, url]
|
||||
proxy_url = None
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
proxy_server, proxy_url = await proxy()
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type, proxy=proxy_url)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
proxy_server.close()
|
||||
await proxy_server.wait_closed()
|
||||
assert "Proxy: localhost" in capfd.readouterr().out
|
||||
await server.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import uuid
|
|||
from pathlib import Path
|
||||
from typing import Any, Set
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -125,9 +124,7 @@ class TestGetProjectRoot:
|
|||
async def test_parse_data_exception(self, filename, want):
|
||||
pathname = Path(__file__).parent.parent.parent / "data/output_parser" / filename
|
||||
assert pathname.exists()
|
||||
async with aiofiles.open(str(pathname), mode="r") as reader:
|
||||
data = await reader.read()
|
||||
|
||||
data = await aread(filename=pathname)
|
||||
result = OutputParser.parse_data(data=data)
|
||||
assert want in result
|
||||
|
||||
|
|
@ -178,7 +175,7 @@ class TestGetProjectRoot:
|
|||
],
|
||||
)
|
||||
def test_split_namespace(self, val, want):
|
||||
res = split_namespace(val)
|
||||
res = split_namespace(val, maxsplit=-1)
|
||||
assert res == want
|
||||
|
||||
def test_read_json_file(self):
|
||||
|
|
@ -198,12 +195,25 @@ class TestGetProjectRoot:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_write(self):
|
||||
pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp"
|
||||
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.tmp"
|
||||
await awrite(pathname, "ABC")
|
||||
data = await aread(pathname)
|
||||
assert data == "ABC"
|
||||
pathname.unlink(missing_ok=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_write_error_charset(self):
|
||||
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.txt"
|
||||
content = "中国abc123\u27f6"
|
||||
await awrite(filename=pathname, data=content)
|
||||
data = await aread(filename=pathname)
|
||||
assert data == content
|
||||
|
||||
content = "GB18030 是中国国家标准局发布的新一代中文字符集标准,是 GBK 的升级版,支持更广泛的字符范围。"
|
||||
await awrite(filename=pathname, data=content, encoding="gb2312")
|
||||
data = await aread(filename=pathname, encoding="utf-8")
|
||||
assert data == content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -10,15 +10,14 @@
|
|||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.common import awrite
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
async def mock_file(filename, content=""):
|
||||
async with aiofiles.open(str(filename), mode="w") as file:
|
||||
await file.write(content)
|
||||
await awrite(filename=filename, data=content)
|
||||
|
||||
|
||||
async def mock_repo(local_path) -> (GitRepository, Path):
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from metagpt.utils.mermaid import MMC1, mermaid_to_file
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer
|
||||
async def test_mermaid(engine, context):
|
||||
async def test_mermaid(engine, context, mermaid_mocker):
|
||||
# nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli
|
||||
# ink prerequisites: connected to internet
|
||||
# playwright prerequisites: playwright install --with-deps chromium
|
||||
|
|
|
|||
|
|
@ -211,6 +211,11 @@ value
|
|||
output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
|
||||
assert output == target_output
|
||||
|
||||
raw_output = '{"key": "url "http" \\"https\\" "}'
|
||||
target_output = '{"key": "url \\"http\\" \\"https\\" "}'
|
||||
output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 1 column 15 (char 14)")
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_retry_parse_json_text():
|
||||
from metagpt.utils.repair_llm_raw_output import retry_parse_json_text
|
||||
|
|
|
|||
25
tests/metagpt/utils/test_repo_to_markdown.py
Normal file
25
tests/metagpt/utils/test_repo_to_markdown.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.repo_to_markdown import repo_to_markdown
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["repo_path", "output"],
|
||||
[(Path(__file__).parent.parent, Path(__file__).parent.parent.parent / f"workspace/unittest/{uuid.uuid4().hex}.md")],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_repo_to_markdown(repo_path: Path, output: Path):
|
||||
markdown = await repo_to_markdown(repo_path=repo_path, output=output)
|
||||
assert output.exists()
|
||||
assert markdown
|
||||
|
||||
output.unlink(missing_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -9,10 +9,10 @@ import uuid
|
|||
from pathlib import Path
|
||||
|
||||
import aioboto3
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.s3_config import S3Config
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
|
@ -30,6 +30,14 @@ async def test_s3(mocker):
|
|||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
mocker.patch.object(aioboto3.Session, "client", return_value=mock_client)
|
||||
mock_config = mocker.Mock()
|
||||
mock_config.s3 = S3Config(
|
||||
access_key="mock_access_key",
|
||||
secret_key="mock_secret_key",
|
||||
endpoint="http://mock.endpoint",
|
||||
bucket="mock_bucket",
|
||||
)
|
||||
mocker.patch.object(Config, "default", return_value=mock_config)
|
||||
|
||||
# Prerequisites
|
||||
s3 = Config.default().s3
|
||||
|
|
@ -37,7 +45,7 @@ async def test_s3(mocker):
|
|||
conn = S3(s3)
|
||||
object_name = "unittest.bak"
|
||||
await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name)
|
||||
pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak")
|
||||
pathname = (Path(__file__).parent / "../../../workspace/unittest" / uuid.uuid4().hex).with_suffix(".bak")
|
||||
pathname.unlink(missing_ok=True)
|
||||
await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname))
|
||||
assert pathname.exists()
|
||||
|
|
@ -45,8 +53,7 @@ async def test_s3(mocker):
|
|||
assert url
|
||||
bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name)
|
||||
assert bin_data
|
||||
async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader:
|
||||
data = await reader.read()
|
||||
data = await aread(filename=__file__)
|
||||
res = await conn.cache(data, ".bak", "script")
|
||||
assert "http" in res
|
||||
|
||||
|
|
@ -60,8 +67,6 @@ async def test_s3(mocker):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
await reader.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
import nbformat
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.utils.common import read_json_file
|
||||
from metagpt.utils.save_code import DATA_PATH, save_code_file
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def _paragraphs(n):
|
|||
@pytest.mark.parametrize(
|
||||
"msgs, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-0613", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
|
||||
(_msgs(), "gpt-4", "System", 2000, 3),
|
||||
|
|
@ -32,21 +32,23 @@ def _paragraphs(n):
|
|||
],
|
||||
)
|
||||
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
|
||||
assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
|
||||
length = len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000
|
||||
assert length == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, prompt_template, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
|
||||
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
|
||||
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
|
||||
(" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1000, 8),
|
||||
],
|
||||
)
|
||||
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
|
||||
ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
|
||||
assert len(ret) == expected
|
||||
chunk = len(list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved)))
|
||||
assert chunk == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
64
tests/metagpt/utils/test_tree.py
Normal file
64
tests/metagpt/utils/test_tree.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.tree import _print_tree, tree
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("root", "rules"),
|
||||
[
|
||||
(str(Path(__file__).parent / "../.."), None),
|
||||
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
|
||||
],
|
||||
)
|
||||
def test_tree(root: str, rules: str):
|
||||
v = tree(root=root, gitignore=rules)
|
||||
assert v
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("root", "rules"),
|
||||
[
|
||||
(str(Path(__file__).parent / "../.."), None),
|
||||
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
|
||||
],
|
||||
)
|
||||
def test_tree_command(root: str, rules: str):
|
||||
v = tree(root=root, gitignore=rules, run_command=True)
|
||||
assert v
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tree", "want"),
|
||||
[
|
||||
({"a": {"b": {}, "c": {}}}, ["a", "+-- b", "+-- c"]),
|
||||
({"a": {"b": {}, "c": {"d": {}}}}, ["a", "+-- b", "+-- c", " +-- d"]),
|
||||
(
|
||||
{"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}},
|
||||
["a", "+-- b", "| +-- e", "| +-- f", "| +-- g", "+-- c", " +-- d"],
|
||||
),
|
||||
(
|
||||
{"h": {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}, "i": {}}},
|
||||
[
|
||||
"h",
|
||||
"+-- a",
|
||||
"| +-- b",
|
||||
"| | +-- e",
|
||||
"| | +-- f",
|
||||
"| | +-- g",
|
||||
"| +-- c",
|
||||
"| +-- d",
|
||||
"+-- i",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test__print_tree(tree: dict, want: List[str]):
|
||||
v = _print_tree(tree)
|
||||
assert v == want
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
45
tests/metagpt/utils/test_visual_graph_repo.py
Normal file
45
tests/metagpt/utils/test_visual_graph_repo.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4
|
||||
@Author : mashenquan
|
||||
@File : test_visual_graph_repo.py
|
||||
@Desc : Unit tests for testing and demonstrating the usage of VisualDiGraphRepo.
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.common import remove_affix, split_namespace
|
||||
from metagpt.utils.visual_graph_repo import VisualDiGraphRepo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visual_di_graph_repo(context, mocker):
|
||||
filename = Path(__file__).parent / "../../data/graph_db/networkx.sequence_view.json"
|
||||
repo = await VisualDiGraphRepo.load_from(filename=filename)
|
||||
|
||||
class_view = await repo.get_mermaid_class_view()
|
||||
assert class_view
|
||||
await context.repo.resources.graph_repo.save(filename="class_view.md", content=f"```mermaid\n{class_view}\n```\n")
|
||||
|
||||
sequence_views = await repo.get_mermaid_sequence_views()
|
||||
assert sequence_views
|
||||
for ns, sqv in sequence_views:
|
||||
filename = re.sub(r"[:/\\\.]+", "_", ns) + ".sequence_view.md"
|
||||
sqv = sqv.strip(" `")
|
||||
await context.repo.resources.graph_repo.save(filename=filename, content=f"```mermaid\n{sqv}\n```\n")
|
||||
|
||||
sequence_view_vers = await repo.get_mermaid_sequence_view_versions()
|
||||
assert sequence_view_vers
|
||||
for ns, sqv in sequence_view_vers:
|
||||
ver, sqv = split_namespace(sqv)
|
||||
filename = re.sub(r"[:/\\\.]+", "_", ns) + f".{ver}.sequence_view_ver.md"
|
||||
sqv = remove_affix(sqv).strip(" `")
|
||||
await context.repo.resources.graph_repo.save(filename=filename, content=f"```mermaid\n{sqv}\n```\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -10,6 +10,7 @@ class MockAioResponse:
|
|||
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
|
||||
rsp_cache: dict[str, str] = {}
|
||||
name = "aiohttp"
|
||||
status = 200
|
||||
|
||||
def __init__(self, session, method, url, **kwargs) -> None:
|
||||
fn = self.check_funcs.get((method, url))
|
||||
|
|
@ -22,6 +23,7 @@ class MockAioResponse:
|
|||
async def __aenter__(self):
|
||||
if self.response:
|
||||
await self.response.__aenter__()
|
||||
self.status = self.response.status
|
||||
elif self.mng:
|
||||
self.response = await self.mng.__aenter__()
|
||||
return self
|
||||
|
|
@ -41,6 +43,17 @@ class MockAioResponse:
|
|||
self.rsp_cache[self.key] = data
|
||||
return data
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
return self
|
||||
|
||||
async def read(self):
|
||||
if self.key in self.rsp_cache:
|
||||
return eval(self.rsp_cache[self.key])
|
||||
data = await self.response.content.read()
|
||||
self.rsp_cache[self.key] = str(data)
|
||||
return data
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.response:
|
||||
self.response.raise_for_status()
|
||||
|
|
|
|||
|
|
@ -3,8 +3,9 @@ from typing import Optional, Union
|
|||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -24,31 +25,21 @@ class MockLLM(OriginalLLM):
|
|||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""Overwrite original acompletion_text to cancel retry"""
|
||||
if stream:
|
||||
resp = self._achat_completion_stream(messages, timeout=timeout)
|
||||
|
||||
collected_messages = []
|
||||
async for i in resp:
|
||||
log_llm_stream(i)
|
||||
collected_messages.append(i)
|
||||
|
||||
full_reply_content = "".join(collected_messages)
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
resp = await self._achat_completion_stream(messages, timeout=timeout)
|
||||
return resp
|
||||
|
||||
rsp = await self._achat_completion(messages, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
async def original_aask(
|
||||
self,
|
||||
msg: str,
|
||||
msg: Union[str, list[dict[str, str]]],
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
stream=True,
|
||||
):
|
||||
"""A copy of metagpt.provider.base_llm.BaseLLM.aask, we can't use super().aask because it will be mocked"""
|
||||
) -> str:
|
||||
if system_msgs:
|
||||
message = self._system_msgs(system_msgs)
|
||||
else:
|
||||
|
|
@ -57,7 +48,11 @@ class MockLLM(OriginalLLM):
|
|||
message = []
|
||||
if format_msgs:
|
||||
message.extend(format_msgs)
|
||||
message.append(self._user_msg(msg, images=images))
|
||||
if isinstance(msg, str):
|
||||
message.append(self._user_msg(msg, images=images))
|
||||
else:
|
||||
message.extend(msg)
|
||||
logger.debug(message)
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
return rsp
|
||||
|
||||
|
|
@ -76,19 +71,27 @@ class MockLLM(OriginalLLM):
|
|||
A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it will be mocked.
|
||||
Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
|
||||
"""
|
||||
if "tools" not in kwargs:
|
||||
configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
|
||||
kwargs.update(configs)
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
msg: Union[str, list[dict[str, str]]],
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
stream=True,
|
||||
) -> str:
|
||||
msg_key = msg # used to identify it a message has been called before
|
||||
# used to identify it a message has been called before
|
||||
if isinstance(msg, list):
|
||||
msg_key = "#MSG_SEP#".join([m["content"] for m in msg])
|
||||
else:
|
||||
msg_key = msg
|
||||
|
||||
if system_msgs:
|
||||
joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#"
|
||||
msg_key = joined_system_msg + msg_key
|
||||
|
|
@ -101,8 +104,7 @@ class MockLLM(OriginalLLM):
|
|||
return rsp
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
messages = self._process_message(messages)
|
||||
msg_key = json.dumps(messages, ensure_ascii=False)
|
||||
msg_key = json.dumps(self.format_msg(messages), ensure_ascii=False)
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
|
||||
return rsp
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +0,0 @@
|
|||
llm:
|
||||
api_type: "spark"
|
||||
app_id: "xxx"
|
||||
api_key: "xxx"
|
||||
api_secret: "xxx"
|
||||
domain: "generalv2"
|
||||
base_url: "wss://spark-api.xf-yun.com/v3.1/chat"
|
||||
Loading…
Add table
Add a link
Reference in a new issue