mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-18 13:55:17 +02:00
Merge branch 'main' into feat_memory
This commit is contained in:
commit
d7968fdc20
160 changed files with 2194 additions and 2096 deletions
|
|
@ -1,51 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/11/2024 8:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.di.debug_code import DebugCode
|
||||
from metagpt.schema import Message
|
||||
|
||||
ErrorStr = """Tested passed:
|
||||
|
||||
Tests failed:
|
||||
assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5]
|
||||
"""
|
||||
|
||||
CODE = """
|
||||
def sort_array(arr):
|
||||
# Helper function to count the number of ones in the binary representation
|
||||
def count_ones(n):
|
||||
return bin(n).count('1')
|
||||
|
||||
# Sort the array using a custom key function
|
||||
# The key function returns a tuple (number of ones, value) for each element
|
||||
# This ensures that if two elements have the same number of ones, they are sorted by their value
|
||||
sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x))
|
||||
|
||||
return sorted_arr
|
||||
```
|
||||
"""
|
||||
|
||||
DebugContext = '''Solve the problem in Python:
|
||||
def sort_array(arr):
|
||||
"""
|
||||
In this Kata, you have to sort an array of non-negative integers according to
|
||||
number of ones in their binary representation in ascending order.
|
||||
For similar number of ones, sort based on decimal value.
|
||||
|
||||
It must be implemented like this:
|
||||
>>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
|
||||
>>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
|
||||
>>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
|
||||
"""
|
||||
'''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_code():
|
||||
debug_context = Message(content=DebugContext)
|
||||
new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
|
||||
assert "def sort_array(arr)" in new_code["code"]
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode, truncate
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -8,6 +8,7 @@ async def test_code_running():
|
|||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run("print('hello world!')")
|
||||
assert is_success
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -17,6 +18,7 @@ async def test_split_code_running():
|
|||
_ = await executor.run("z=x+y")
|
||||
output, is_success = await executor.run("assert z==3")
|
||||
assert is_success
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -24,6 +26,7 @@ async def test_execute_error():
|
|||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run("z=1/0")
|
||||
assert not is_success
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
PLOT_CODE = """
|
||||
|
|
@ -52,21 +55,7 @@ async def test_plotting_code():
|
|||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run(PLOT_CODE)
|
||||
assert is_success
|
||||
|
||||
|
||||
def test_truncate():
|
||||
# 代码执行成功
|
||||
output, is_success = truncate("hello world", 5, True)
|
||||
assert "Truncated to show only first 5 characters\nhello" in output
|
||||
assert is_success
|
||||
# 代码执行失败
|
||||
output, is_success = truncate("hello world", 5, False)
|
||||
assert "Truncated to show only last 5 characters\nworld" in output
|
||||
assert not is_success
|
||||
# 异步
|
||||
output, is_success = truncate("<coroutine object", 5, True)
|
||||
assert not is_success
|
||||
assert "await" in output
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -76,6 +65,7 @@ async def test_run_with_timeout():
|
|||
message, success = await executor.run(code)
|
||||
assert not success
|
||||
assert message.startswith("Cell execution timed out")
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -83,7 +73,7 @@ async def test_run_code_text():
|
|||
executor = ExecuteNbCode()
|
||||
message, success = await executor.run(code='print("This is a code!")', language="python")
|
||||
assert success
|
||||
assert message == "This is a code!\n"
|
||||
assert "This is a code!" in message
|
||||
message, success = await executor.run(code="# This is a code!", language="markdown")
|
||||
assert success
|
||||
assert message == "# This is a code!"
|
||||
|
|
@ -91,19 +81,22 @@ async def test_run_code_text():
|
|||
message, success = await executor.run(code=mix_text, language="markdown")
|
||||
assert success
|
||||
assert message == mix_text
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_terminate():
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await executor.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await executor.terminate()
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
assert executor.nb_client.km is None
|
||||
@pytest.mark.parametrize(
|
||||
"k", [(1), (5)]
|
||||
) # k=1 to test a single regular terminate, k>1 to test terminate under continuous run
|
||||
async def test_terminate(k):
|
||||
for _ in range(k):
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await executor.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await executor.terminate()
|
||||
assert executor.nb_client.km is None
|
||||
assert executor.nb_client.kc is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -114,3 +107,22 @@ async def test_reset():
|
|||
assert is_kernel_alive
|
||||
await executor.reset()
|
||||
assert executor.nb_client.km is None
|
||||
await executor.terminate()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_outputs():
|
||||
executor = ExecuteNbCode()
|
||||
code = """
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({'ID': [1,2,3], 'NAME': ['a', 'b', 'c']})
|
||||
print(df.columns)
|
||||
print(f"columns num:{len(df.columns)}")
|
||||
print(df['DUMMPY_ID'])
|
||||
"""
|
||||
output, is_success = await executor.run(code)
|
||||
assert not is_success
|
||||
assert "Index(['ID', 'NAME'], dtype='object')" in output
|
||||
assert "KeyError: 'DUMMPY_ID'" in output
|
||||
assert "columns num:2" in output
|
||||
await executor.terminate()
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.di.ml_action import WriteCodeWithToolsML
|
||||
from metagpt.schema import Plan, Task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code_ml = WriteCodeWithToolsML()
|
||||
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="随机生成一个pandas DataFrame数据集",
|
||||
task_type="other",
|
||||
dependent_task_ids=[],
|
||||
code="""
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({
|
||||
'a': [1, 2, 3, 4, 5],
|
||||
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
|
||||
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
|
||||
'd': [1, 2, 3, 4, 5]
|
||||
})
|
||||
""",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="对数据集进行数据清洗",
|
||||
task_type="data_preprocess",
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
plan = Plan(
|
||||
goal="构造数据集并进行数据清洗",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
column_info = ""
|
||||
|
||||
_, code_with_ml = await write_code_ml.run([], plan, column_info)
|
||||
code_with_ml = code_with_ml["code"]
|
||||
assert len(code_with_ml) > 0
|
||||
print(code_with_ml)
|
||||
|
|
@ -1,134 +1,61 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.di.write_analysis_code import (
|
||||
WriteCodeWithoutTools,
|
||||
WriteCodeWithTools,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.strategy.planner import STRUCTURAL_CONTEXT
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_by_list_plan():
|
||||
write_code = WriteCodeWithoutTools()
|
||||
execute_code = ExecuteNbCode()
|
||||
messages = []
|
||||
plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "回顾已完成的任务", "求均值", "总结"]
|
||||
for task in plan:
|
||||
print(f"\n任务: {task}\n\n")
|
||||
messages.append(Message(task, role="assistant"))
|
||||
code = await write_code.run(messages)
|
||||
if task.startswith(("回顾", "总结")):
|
||||
assert code["language"] == "markdown"
|
||||
else:
|
||||
assert code["language"] == "python"
|
||||
messages.append(Message(code["code"], role="assistant"))
|
||||
assert len(code) > 0
|
||||
output, _ = await execute_code.run(**code)
|
||||
print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n")
|
||||
messages.append(output)
|
||||
from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_recommendation():
|
||||
task = "clean and preprocess the data"
|
||||
available_tools = {
|
||||
"FillMissingValue": "Filling missing values",
|
||||
"SplitBins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
|
||||
}
|
||||
write_code = WriteCodeWithTools()
|
||||
tools = await write_code._recommend_tool(task, available_tools)
|
||||
async def test_write_code_with_plan():
|
||||
write_code = WriteAnalysisCode()
|
||||
|
||||
assert len(tools) == 1
|
||||
assert "FillMissingValue" in tools
|
||||
user_requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
plan_status = "\n## Finished Tasks\n### code\n```python\n\n```\n\n### execution result\n\n\n## Current Task\nLoad the sklearn Iris dataset and perform exploratory data analysis\n\n## Task Guidance\nWrite complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.\nSpecifically, \nThe current task is about exploratory data analysis, please note the following:\n- Distinguish column types with `select_dtypes` for tailored analysis and visualization, such as correlation.\n- Remember to `import numpy as np` before using Numpy functions.\n\n"
|
||||
|
||||
code = await write_code.run(user_requirement=user_requirement, plan_status=plan_status)
|
||||
assert len(code) > 0
|
||||
assert "sklearn" in code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code = WriteCodeWithTools()
|
||||
write_code = WriteAnalysisCode()
|
||||
|
||||
requirement = "构造数据集并进行数据清洗"
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="随机生成一个pandas DataFrame数据集",
|
||||
task_type="other",
|
||||
dependent_task_ids=[],
|
||||
code="""
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({
|
||||
'a': [1, 2, 3, 4, 5],
|
||||
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
|
||||
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
|
||||
'd': [1, 2, 3, 4, 5]
|
||||
})
|
||||
""",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="对数据集进行数据清洗",
|
||||
task_type="data_preprocess",
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
plan = Plan(
|
||||
goal="构造数据集并进行数据清洗",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
user_requirement = "Preprocess sklearn Wine recognition dataset and train a model to predict wine class (20% as validation), and show validation accuracy."
|
||||
tool_info = """
|
||||
## Capabilities
|
||||
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python class or function.
|
||||
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
|
||||
|
||||
context = STRUCTURAL_CONTEXT.format(
|
||||
user_requirement=requirement,
|
||||
context=plan.context,
|
||||
tasks=list(task_map.values()),
|
||||
current_task=plan.current_task.model_dump_json(),
|
||||
)
|
||||
context_msg = [Message(content=context, role="user")]
|
||||
## Available Tools:
|
||||
Each tool is described in JSON format. When you call a tool, import the tool from its path first.
|
||||
{'FillMissingValue': {'type': 'class', 'description': 'Completing missing values with simple strategies.', 'methods': {'__init__': {'type': 'function', 'description': 'Initialize self. ', 'signature': '(self, features: \'list\', strategy: "Literal[\'mean\', \'median\', \'most_frequent\', \'constant\']" = \'mean\', fill_value=None)', 'parameters': 'Args: features (list): Columns to be processed. strategy (Literal["mean", "median", "most_frequent", "constant"], optional): The imputation strategy, notice \'mean\' and \'median\' can only be used for numeric features. Defaults to \'mean\'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.'}, 'fit': {'type': 'function', 'description': 'Fit a model to be used in subsequent transform. ', 'signature': "(self, df: 'pd.DataFrame')", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame.'}, 'fit_transform': {'type': 'function', 'description': 'Fit and transform the input DataFrame. ', 'signature': "(self, df: 'pd.DataFrame') -> 'pd.DataFrame'", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.'}, 'transform': {'type': 'function', 'description': 'Transform the input DataFrame with the fitted model. ', 'signature': "(self, df: 'pd.DataFrame') -> 'pd.DataFrame'", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.'}}, 'tool_path': 'metagpt/tools/libs/data_preprocess.py'}
|
||||
"""
|
||||
|
||||
code = await write_code.run(context_msg, plan)
|
||||
code = code["code"]
|
||||
code = await write_code.run(user_requirement=user_requirement, tool_info=tool_info)
|
||||
assert len(code) > 0
|
||||
print(code)
|
||||
assert "metagpt.tools.libs" in code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_to_correct_error():
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
read a dataset test.csv and print its head
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "import pandas and load the dataset from 'test.csv'.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Print the head of the dataset to display the first few rows.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
async def test_debug_with_reflection():
|
||||
user_requirement = "read a dataset test.csv and print its head"
|
||||
|
||||
plan_status = """
|
||||
## Finished Tasks
|
||||
### code
|
||||
```python
|
||||
```
|
||||
|
||||
### execution result
|
||||
|
||||
## Current Task
|
||||
{"task_id": "1", "dependent_task_ids": [], "instruction": "import pandas and load the dataset from 'test.csv'.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
import pandas and load the dataset from 'test.csv'.
|
||||
|
||||
## Task Guidance
|
||||
Write complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.
|
||||
Specifically,
|
||||
"""
|
||||
|
||||
wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv
|
||||
error = """
|
||||
Traceback (most recent call last):
|
||||
|
|
@ -139,186 +66,14 @@ async def test_write_code_to_correct_error():
|
|||
raise ValueError(
|
||||
ValueError: Excel file format cannot be determined, you must specify an engine manually.
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
working_memory = [
|
||||
Message(content=wrong_code, role="assistant"),
|
||||
Message(content=error, role="user"),
|
||||
]
|
||||
new_code = await WriteCodeWithoutTools().run(context=context)
|
||||
new_code = new_code["code"]
|
||||
print(new_code)
|
||||
new_code = await WriteAnalysisCode().run(
|
||||
user_requirement=user_requirement,
|
||||
plan_status=plan_status,
|
||||
working_memory=working_memory,
|
||||
use_reflection=True,
|
||||
)
|
||||
assert "read_csv" in new_code # should correct read_excel to read_csv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_simple():
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
read a dataset test.csv and print its head
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "import pandas and load the dataset from 'test.csv'.",
|
||||
"task_type": "",
|
||||
"code": "import pandas as pd\ndata = pd.read_csv('test.csv')",
|
||||
"result": "",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Print the head of the dataset to display the first few rows.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Print the head of the dataset to display the first few rows.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
code = await WriteCodeWithoutTools().run(context=context)
|
||||
code = code["code"]
|
||||
print(code)
|
||||
assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long():
|
||||
"""test code reuse for long context"""
|
||||
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
Run data analysis on sklearn Iris dataset, include a plot
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "Load the Iris dataset from sklearn.",
|
||||
"task_type": "",
|
||||
"code": "from sklearn.datasets import load_iris\niris_data = load_iris()\niris_data['data'][0:5], iris_data['target'][0:5]",
|
||||
"result": "(array([[5.1, 3.5, 1.4, 0.2],\n [4.9, 3. , 1.4, 0.2],\n [4.7, 3.2, 1.3, 0.2],\n [4.6, 3.1, 1.5, 0.2],\n [5. , 3.6, 1.4, 0.2]]),\n array([0, 0, 0, 0, 0]))",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Perform exploratory data analysis on the Iris dataset.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "3",
|
||||
"dependent_task_ids": [
|
||||
"2"
|
||||
],
|
||||
"instruction": "Create a plot visualizing the Iris dataset features.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Perform exploratory data analysis on the Iris dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
trials_num = 5
|
||||
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
|
||||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_iris" not in result["code"] and "iris_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
assert success_rate >= 0.8
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long_for_wine():
|
||||
"""test code reuse for long context"""
|
||||
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "Load the sklearn Wine recognition dataset and perform exploratory data analysis."
|
||||
"task_type": "",
|
||||
"code": "from sklearn.datasets import load_wine\n# Load the Wine recognition dataset\nwine_data = load_wine()\n# Perform exploratory data analysis\nwine_data.keys()",
|
||||
"result": "Truncated to show only the last 1000 characters\ndict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names'])",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": ["1"],
|
||||
"instruction": "Create a plot to visualize some aspect of the wine dataset."
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "3",
|
||||
"dependent_task_ids": ["1"],
|
||||
"instruction": "Split the dataset into training and validation sets with a 20% validation size.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "4",
|
||||
"dependent_task_ids": ["3"],
|
||||
"instruction": "Train a model on the training set to predict wine class.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "5",
|
||||
"dependent_task_ids": ["4"],
|
||||
"instruction": "Evaluate the model on the validation set and report the accuracy.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Create a plot to visualize some aspect of the Wine dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
trials_num = 5
|
||||
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
|
||||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_wine" not in result["code"] and "wine_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
assert success_rate >= 0.8
|
||||
|
|
|
|||
|
|
@ -23,12 +23,10 @@ def test_precheck_update_plan_from_rsp():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("use_tools", [(False), (True)])
|
||||
async def test_write_plan(use_tools):
|
||||
async def test_write_plan():
|
||||
rsp = await WritePlan().run(
|
||||
context=[Message("run analysis on sklearn iris dataset", role="user")], use_tools=use_tools
|
||||
context=[Message("Run data analysis on sklearn Iris dataset, include a plot", role="user")]
|
||||
)
|
||||
|
||||
assert "task_id" in rsp
|
||||
assert "instruction" in rsp
|
||||
assert "json" not in rsp # the output should be the content inside ```json ```
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from metagpt.actions.rebuild_class_view import RebuildClassView
|
|||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild(context):
|
||||
action = RebuildClassView(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from metagpt.utils.git_repository import ChangeType
|
|||
from metagpt.utils.graph_repository import SPO
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild(context, mocker):
|
||||
# Mock
|
||||
|
|
@ -47,6 +48,8 @@ async def test_rebuild(context, mocker):
|
|||
context=context,
|
||||
)
|
||||
await action.run()
|
||||
rows = await action.graph_db.select()
|
||||
assert rows
|
||||
assert context.repo.docs.graph_repo.changed_files
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of MincraftExtEnv
|
||||
|
||||
|
||||
from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
|
||||
|
||||
|
||||
def test_mincraft_ext_env():
|
||||
ext_env = MincraftExtEnv()
|
||||
assert ext_env.server, f"{ext_env.server_host}:{ext_env.server_port}"
|
||||
assert MC_CKPT_DIR.joinpath("skill/code").exists()
|
||||
assert ext_env.warm_up.get("optional_inventory_items") == 7
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of MinecraftExtEnv
|
||||
|
||||
|
||||
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
|
||||
|
||||
|
||||
def test_minecraft_ext_env():
|
||||
ext_env = MinecraftExtEnv()
|
||||
assert ext_env.server, f"{ext_env.server_host}:{ext_env.server_port}"
|
||||
assert MC_CKPT_DIR.joinpath("skill/code").exists()
|
||||
assert ext_env.warm_up.get("optional_inventory_items") == 7
|
||||
|
|
@ -11,6 +11,7 @@ import pytest
|
|||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
default_resp_cont,
|
||||
get_part_chat_completion,
|
||||
|
|
@ -22,7 +23,7 @@ name = "GPT"
|
|||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
pass
|
||||
self.config = config or mock_llm_config
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return get_part_chat_completion(name)
|
||||
|
|
|
|||
|
|
@ -10,10 +10,9 @@ async def test_interpreter(mocker, auto_run):
|
|||
mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
|
||||
mocker.patch("builtins.input", return_value="confirm")
|
||||
|
||||
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
tools = []
|
||||
requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
|
||||
|
||||
di = DataInterpreter(auto_run=auto_run, use_tools=True, tools=tools)
|
||||
di = DataInterpreter(auto_run=auto_run)
|
||||
rsp = await di.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
@ -21,3 +20,15 @@ async def test_interpreter(mocker, auto_run):
|
|||
finished_tasks = di.planner.plan.get_finished_tasks()
|
||||
assert len(finished_tasks) > 0
|
||||
assert len(finished_tasks[0].code) > 0 # check one task to see if code is recorded
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interpreter_react_mode(mocker):
|
||||
mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
|
||||
|
||||
requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
|
||||
|
||||
di = DataInterpreter(react_mode="react")
|
||||
rsp = await di.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
|
|||
|
|
@ -1,90 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.ml_engineer import MLEngineer
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from tests.metagpt.actions.di.test_debug_code import CODE, DebugContext, ErrorStr
|
||||
|
||||
|
||||
def test_mle_init():
|
||||
mle = MLEngineer(goal="test", auto_run=True, use_tools=True, tools=["tool1", "tool2"])
|
||||
assert mle.tools == []
|
||||
|
||||
|
||||
MockPlan = Plan(
|
||||
goal="This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: 'tests/data/ml_datasets/titanic/split_train.csv', eval data path: 'tests/data/ml_datasets/titanic/split_eval.csv'.",
|
||||
context="",
|
||||
tasks=[
|
||||
Task(
|
||||
task_id="1",
|
||||
dependent_task_ids=[],
|
||||
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
|
||||
task_type="eda",
|
||||
code="",
|
||||
result="",
|
||||
is_success=False,
|
||||
is_finished=False,
|
||||
)
|
||||
],
|
||||
task_map={
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
dependent_task_ids=[],
|
||||
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
|
||||
task_type="eda",
|
||||
code="",
|
||||
result="",
|
||||
is_success=False,
|
||||
is_finished=False,
|
||||
)
|
||||
},
|
||||
current_task_id="1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_write_code(mocker):
|
||||
data_path = "tests/data/ml_datasets/titanic"
|
||||
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.planner.plan = MockPlan
|
||||
|
||||
code, _ = await mle._write_code()
|
||||
assert data_path in code["code"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_update_data_columns(mocker):
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.planner.plan = MockPlan
|
||||
|
||||
# manually update task type to test update
|
||||
mle.planner.plan.current_task.task_type = ToolType.DATA_PREPROCESS.value
|
||||
|
||||
result = await mle._update_data_columns()
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_debug_code(mocker):
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.working_memory.add(Message(content=ErrorStr, cause_by=ExecuteNbCode))
|
||||
mle.latest_code = CODE
|
||||
mle.debug_context = DebugContext
|
||||
code, _ = await mle._write_code()
|
||||
assert len(code) > 0
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_ml_engineer():
|
||||
data_path = "tests/data/ml_datasets/titanic"
|
||||
requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
|
||||
tools = ["FillMissingValue", "CatCross", "dummy_tool"]
|
||||
|
||||
mle = MLEngineer(auto_run=True, use_tools=True, tools=tools)
|
||||
rsp = await mle.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
@ -6,11 +6,11 @@
|
|||
@File : test_tutorial_assistant.py
|
||||
"""
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.const import TUTORIAL_PATH
|
||||
from metagpt.roles.tutorial_assistant import TutorialAssistant
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -20,9 +20,8 @@ async def test_tutorial_assistant(language: str, topic: str, context):
|
|||
msg = await role.run(topic)
|
||||
assert TUTORIAL_PATH.exists()
|
||||
filename = msg.content
|
||||
async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
|
||||
content = await reader.read()
|
||||
assert "pip" in content
|
||||
content = await aread(filename=filename)
|
||||
assert "pip" in content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
37
tests/metagpt/strategy/test_planner.py
Normal file
37
tests/metagpt/strategy/test_planner.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from metagpt.schema import Plan, Task
|
||||
from metagpt.strategy.planner import Planner
|
||||
from metagpt.strategy.task_type import TaskType
|
||||
|
||||
MOCK_TASK_MAP = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="test instruction for finished task",
|
||||
task_type=TaskType.EDA.type_name,
|
||||
dependent_task_ids=[],
|
||||
code="some finished test code",
|
||||
result="some finished test result",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="test instruction for current task",
|
||||
task_type=TaskType.DATA_PREPROCESS.type_name,
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
MOCK_PLAN = Plan(
|
||||
goal="test goal",
|
||||
tasks=list(MOCK_TASK_MAP.values()),
|
||||
task_map=MOCK_TASK_MAP,
|
||||
current_task_id="2",
|
||||
)
|
||||
|
||||
|
||||
def test_planner_get_plan_status():
|
||||
planner = Planner(plan=MOCK_PLAN)
|
||||
status = planner.get_plan_status()
|
||||
|
||||
assert "some finished test code" in status
|
||||
assert "some finished test result" in status
|
||||
assert "test instruction for current task" in status
|
||||
assert TaskType.DATA_PREPROCESS.value.guidance in status # current task guidance
|
||||
|
|
@ -60,18 +60,24 @@ async def test_generate_webpages(mock_webpage_filename_with_styles_and_scripts,
|
|||
async def test_save_webpages_with_styles_and_scripts(mock_webpage_filename_with_styles_and_scripts, image_path):
|
||||
generator = GPTvGenerator()
|
||||
webpages = await generator.generate_webpages(image_path)
|
||||
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
|
||||
webpages_dir = generator.save_webpages(webpages=webpages, save_folder_name="test_1")
|
||||
logs.logger.info(webpages_dir)
|
||||
assert webpages_dir.exists()
|
||||
assert (webpages_dir / "index.html").exists()
|
||||
assert (webpages_dir / "styles.css").exists()
|
||||
assert (webpages_dir / "scripts.js").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_webpages_with_style_and_script(mock_webpage_filename_with_style_and_script, image_path):
|
||||
generator = GPTvGenerator()
|
||||
webpages = await generator.generate_webpages(image_path)
|
||||
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
|
||||
webpages_dir = generator.save_webpages(webpages=webpages, save_folder_name="test_2")
|
||||
logs.logger.info(webpages_dir)
|
||||
assert webpages_dir.exists()
|
||||
assert (webpages_dir / "index.html").exists()
|
||||
assert (webpages_dir / "style.css").exists()
|
||||
assert (webpages_dir / "script.js").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from metagpt.tools.libs.web_scraping import scrape_web_playwright
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scrape_web_playwright():
|
||||
test_url = "https://www.deepwisdom.ai"
|
||||
async def test_scrape_web_playwright(http_server):
|
||||
server, test_url = await http_server()
|
||||
|
||||
result = await scrape_web_playwright(test_url)
|
||||
|
||||
|
|
@ -21,3 +21,4 @@ async def test_scrape_web_playwright():
|
|||
assert not result["inner_text"].endswith(" ")
|
||||
assert not result["html"].startswith(" ")
|
||||
assert not result["html"].endswith(" ")
|
||||
await server.stop()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from typing import Callable
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.search_config import SearchConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
|
@ -53,14 +52,11 @@ async def test_search_engine(
|
|||
search_engine_config = {"engine": search_engine_type, "run_func": run_func}
|
||||
|
||||
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-serpapi-key"
|
||||
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-google-key"
|
||||
search_engine_config["cse_id"] = "mock-google-cse"
|
||||
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-serper-key"
|
||||
|
||||
async def test(search_engine):
|
||||
|
|
|
|||
|
|
@ -1,44 +1,8 @@
|
|||
from typing import Literal, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
|
||||
|
||||
|
||||
def test_docstring_to_schema():
|
||||
docstring = """
|
||||
Some test desc.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
|
||||
used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
expected = {
|
||||
"description": "Some test desc.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
}
|
||||
schema = docstring_to_schema(docstring)
|
||||
assert schema == expected
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
|
||||
|
||||
class DummyClass:
|
||||
|
|
@ -81,12 +45,25 @@ class DummyClass:
|
|||
pass
|
||||
|
||||
|
||||
def dummy_fn(df: pd.DataFrame) -> dict:
|
||||
def dummy_fn(
|
||||
df: pd.DataFrame,
|
||||
s: str,
|
||||
k: int = 5,
|
||||
type: Literal["a", "b", "c"] = "a",
|
||||
test_dict: dict[str, int] = None,
|
||||
test_union: Union[str, list[str]] = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Analyzes a DataFrame and categorizes its columns based on data types.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The DataFrame to be analyzed.
|
||||
df: The DataFrame to be analyzed.
|
||||
Another line for df.
|
||||
s (str): Some test string param.
|
||||
Another line for s.
|
||||
k (int, optional): Some test integer param. Defaults to 5.
|
||||
type (Literal["a", "b", "c"], optional): Some test type. Defaults to 'a'.
|
||||
more_args: will be omitted here for testing
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
|
||||
|
|
@ -115,41 +92,21 @@ def test_convert_code_to_tool_schema_class():
|
|||
"methods": {
|
||||
"__init__": {
|
||||
"type": "function",
|
||||
"description": "Initialize self.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"description": "Initialize self. ",
|
||||
"signature": "(self, features: list, strategy: str = 'mean', fill_value=None)",
|
||||
"parameters": "Args: features (list): Columns to be processed. strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
},
|
||||
"fit": {
|
||||
"type": "function",
|
||||
"description": "Fit the FillMissingValue model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"description": "Fit the FillMissingValue model. ",
|
||||
"signature": "(self, df: pandas.core.frame.DataFrame)",
|
||||
"parameters": "Args: df (pd.DataFrame): The input DataFrame.",
|
||||
},
|
||||
"transform": {
|
||||
"type": "function",
|
||||
"description": "Transform the input DataFrame with the fitted model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
"description": "Transform the input DataFrame with the fitted model. ",
|
||||
"signature": "(self, df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame",
|
||||
"parameters": "Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -160,11 +117,9 @@ def test_convert_code_to_tool_schema_class():
|
|||
def test_convert_code_to_tool_schema_function():
|
||||
expected = {
|
||||
"type": "function",
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
|
||||
"signature": "(df: pandas.core.frame.DataFrame, s: str, k: int = 5, type: Literal['a', 'b', 'c'] = 'a', test_dict: dict[str, int] = None, test_union: Union[str, list[str]] = '') -> dict",
|
||||
"parameters": "Args: df: The DataFrame to be analyzed. Another line for df. s (str): Some test string param. Another line for s. k (int, optional): Some test integer param. Defaults to 5. type (Literal[\"a\", \"b\", \"c\"], optional): Some test type. Defaults to 'a'. more_args: will be omitted here for testing Returns: dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others'). Each key corresponds to a list of column names belonging to that category.",
|
||||
}
|
||||
schema = convert_code_to_tool_schema(dummy_fn)
|
||||
assert schema == expected
|
||||
|
|
|
|||
90
tests/metagpt/tools/test_tool_recommend.py
Normal file
90
tests/metagpt/tools/test_tool_recommend.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.schema import Plan, Task
|
||||
from metagpt.tools import TOOL_REGISTRY
|
||||
from metagpt.tools.tool_recommend import (
|
||||
BM25ToolRecommender,
|
||||
ToolRecommender,
|
||||
TypeMatchToolRecommender,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_plan(mocker):
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="conduct feature engineering, add new features on the dataset",
|
||||
task_type="feature engineering",
|
||||
)
|
||||
}
|
||||
plan = Plan(
|
||||
goal="test requirement",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="1",
|
||||
)
|
||||
return plan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bm25_tr(mocker):
|
||||
tr = BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
|
||||
return tr
|
||||
|
||||
|
||||
def test_tr_init():
|
||||
tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping", "non-existing tool"])
|
||||
# web_scraping is a tool tag, it has one tool scrape_web_playwright
|
||||
assert list(tr.tools.keys()) == [
|
||||
"FillMissingValue",
|
||||
"PolynomialExpansion",
|
||||
"scrape_web_playwright",
|
||||
]
|
||||
|
||||
|
||||
def test_tr_init_default_tools_value():
|
||||
tr = ToolRecommender()
|
||||
assert tr.tools == {}
|
||||
|
||||
|
||||
def test_tr_init_tools_all():
|
||||
tr = ToolRecommender(tools=["<all>"])
|
||||
assert list(tr.tools.keys()) == list(TOOL_REGISTRY.get_all_tools().keys())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_tr_recall_with_plan(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recall_tools(plan=mock_plan)
|
||||
assert len(result) == 3
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_tr_recall_no_plan(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recall_tools(
|
||||
context="conduct feature engineering, add new features on the dataset", plan=None
|
||||
)
|
||||
assert len(result) == 3
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_recommend_tools(mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recommend_tools(context="conduct feature engineering, add new features on the dataset")
|
||||
assert len(result) == 2 # web scraping tool should be filtered out at rank stage
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tm_tr_recall_with_plan(mock_plan, mock_bm25_tr):
|
||||
tr = TypeMatchToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
|
||||
result = await tr.recall_tools(plan=mock_plan)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.tool_registry import ToolRegistry
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -9,25 +8,11 @@ def tool_registry():
|
|||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry_full():
|
||||
return ToolRegistry(tool_types=ToolType)
|
||||
|
||||
|
||||
# Test Initialization
|
||||
def test_initialization(tool_registry):
|
||||
assert isinstance(tool_registry, ToolRegistry)
|
||||
assert tool_registry.tools == {}
|
||||
assert tool_registry.tool_types == {}
|
||||
assert tool_registry.tools_by_types == {}
|
||||
|
||||
|
||||
# Test Initialization with tool types
|
||||
def test_initialize_with_tool_types(tool_registry_full):
|
||||
assert isinstance(tool_registry_full, ToolRegistry)
|
||||
assert tool_registry_full.tools == {}
|
||||
assert tool_registry_full.tools_by_types == {}
|
||||
assert "data_preprocess" in tool_registry_full.tool_types
|
||||
assert tool_registry.tools_by_tags == {}
|
||||
|
||||
|
||||
class TestClassTool:
|
||||
|
|
@ -72,31 +57,24 @@ def test_get_tool(tool_registry):
|
|||
assert "description" in tool.schemas
|
||||
|
||||
|
||||
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
|
||||
def test_has_tool_type(tool_registry_full):
|
||||
assert tool_registry_full.has_tool_type("data_preprocess")
|
||||
assert not tool_registry_full.has_tool_type("NonexistentType")
|
||||
def test_has_tool_tag(tool_registry):
|
||||
tool_registry.register_tool(
|
||||
"TestClassTool", "/path/to/tool", tool_source_object=TestClassTool, tags=["machine learning", "test"]
|
||||
)
|
||||
assert tool_registry.has_tool_tag("test")
|
||||
assert not tool_registry.has_tool_tag("Non-existent tag")
|
||||
|
||||
|
||||
def test_get_tool_type(tool_registry_full):
|
||||
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
|
||||
assert retrieved_type is not None
|
||||
assert retrieved_type.name == "data_preprocess"
|
||||
|
||||
|
||||
def test_get_tools_by_type(tool_registry):
|
||||
tool_type_name = "TestType"
|
||||
def test_get_tools_by_tag(tool_registry):
|
||||
tool_tag_name = "Test Tag"
|
||||
tool_name = "TestTool"
|
||||
tool_path = "/path/to/tool"
|
||||
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
|
||||
tool_registry.register_tool(tool_name, tool_path, tags=[tool_tag_name], tool_source_object=TestClassTool)
|
||||
|
||||
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
|
||||
assert tools_by_type is not None
|
||||
assert tool_name in tools_by_type
|
||||
tools_by_tag = tool_registry.get_tools_by_tag(tool_tag_name)
|
||||
assert tools_by_tag is not None
|
||||
assert tool_name in tools_by_tag
|
||||
|
||||
|
||||
# Test case for when the tool type does not exist
|
||||
def test_get_tools_by_nonexistent_type(tool_registry):
|
||||
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
|
||||
assert not tools_by_type
|
||||
tools_by_tag_non_existent = tool_registry.get_tools_by_tag("Non-existent Tag")
|
||||
assert not tools_by_tag_non_existent
|
||||
|
|
|
|||
|
|
@ -9,14 +9,16 @@ from metagpt.utils.parse_html import WebPage
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"browser_type, url, urls",
|
||||
"browser_type",
|
||||
[
|
||||
(WebBrowserEngineType.PLAYWRIGHT, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
(WebBrowserEngineType.SELENIUM, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
WebBrowserEngineType.PLAYWRIGHT,
|
||||
WebBrowserEngineType.SELENIUM,
|
||||
],
|
||||
ids=["playwright", "selenium"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, url, urls):
|
||||
async def test_scrape_web_page(browser_type, http_server):
|
||||
server, url = await http_server()
|
||||
urls = [url, url, url]
|
||||
browser = web_browser_engine.WebBrowserEngine(engine=browser_type)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
|
|
@ -27,6 +29,7 @@ async def test_scrape_web_page(browser_type, url, urls):
|
|||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
await server.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -9,18 +9,28 @@ from metagpt.utils.parse_html import WebPage
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"browser_type, use_proxy, kwagrs, url, urls",
|
||||
"browser_type, use_proxy, kwagrs,",
|
||||
[
|
||||
("chromium", {"proxy": True}, {}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
|
||||
("firefox", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
|
||||
("webkit", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
|
||||
("chromium", {"proxy": True}, {}),
|
||||
(
|
||||
"firefox",
|
||||
{},
|
||||
{"ignore_https_errors": True},
|
||||
),
|
||||
(
|
||||
"webkit",
|
||||
{},
|
||||
{"ignore_https_errors": True},
|
||||
),
|
||||
],
|
||||
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd):
|
||||
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, proxy, capfd, http_server):
|
||||
server, url = await http_server()
|
||||
urls = [url, url, url]
|
||||
proxy_url = None
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
proxy_server, proxy_url = await proxy()
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, proxy=proxy_url, **kwagrs)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
|
|
@ -32,8 +42,10 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
|
|||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
proxy_server.close()
|
||||
await proxy_server.wait_closed()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
await server.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
import browsers
|
||||
import pytest
|
||||
|
||||
|
|
@ -10,51 +11,48 @@ from metagpt.utils.parse_html import WebPage
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"browser_type, use_proxy, url, urls",
|
||||
"browser_type, use_proxy,",
|
||||
[
|
||||
pytest.param(
|
||||
"chrome",
|
||||
True,
|
||||
"https://deepwisdom.ai",
|
||||
("https://deepwisdom.ai",),
|
||||
False,
|
||||
marks=pytest.mark.skipif(not browsers.get("chrome"), reason="chrome browser not found"),
|
||||
),
|
||||
pytest.param(
|
||||
"firefox",
|
||||
False,
|
||||
"https://deepwisdom.ai",
|
||||
("https://deepwisdom.ai",),
|
||||
marks=pytest.mark.skipif(not browsers.get("firefox"), reason="firefox browser not found"),
|
||||
),
|
||||
pytest.param(
|
||||
"edge",
|
||||
False,
|
||||
"https://deepwisdom.ai",
|
||||
("https://deepwisdom.ai",),
|
||||
marks=pytest.mark.skipif(not browsers.get("msedge"), reason="edge browser not found"),
|
||||
),
|
||||
],
|
||||
ids=["chrome-normal", "firefox-normal", "edge-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd):
|
||||
async def test_scrape_web_page(browser_type, use_proxy, proxy, capfd, http_server):
|
||||
# Prerequisites
|
||||
# firefox, chrome, Microsoft Edge
|
||||
server, url = await http_server()
|
||||
urls = [url, url, url]
|
||||
proxy_url = None
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
proxy_server, proxy_url = await proxy()
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type, proxy=proxy_url)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
proxy_server.close()
|
||||
await proxy_server.wait_closed()
|
||||
assert "Proxy: localhost" in capfd.readouterr().out
|
||||
await server.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import uuid
|
|||
from pathlib import Path
|
||||
from typing import Any, Set
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -125,9 +124,7 @@ class TestGetProjectRoot:
|
|||
async def test_parse_data_exception(self, filename, want):
|
||||
pathname = Path(__file__).parent.parent.parent / "data/output_parser" / filename
|
||||
assert pathname.exists()
|
||||
async with aiofiles.open(str(pathname), mode="r") as reader:
|
||||
data = await reader.read()
|
||||
|
||||
data = await aread(filename=pathname)
|
||||
result = OutputParser.parse_data(data=data)
|
||||
assert want in result
|
||||
|
||||
|
|
@ -198,12 +195,25 @@ class TestGetProjectRoot:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_write(self):
|
||||
pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp"
|
||||
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.tmp"
|
||||
await awrite(pathname, "ABC")
|
||||
data = await aread(pathname)
|
||||
assert data == "ABC"
|
||||
pathname.unlink(missing_ok=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_write_error_charset(self):
|
||||
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.txt"
|
||||
content = "中国abc123\u27f6"
|
||||
await awrite(filename=pathname, data=content)
|
||||
data = await aread(filename=pathname)
|
||||
assert data == content
|
||||
|
||||
content = "GB18030 是中国国家标准局发布的新一代中文字符集标准,是 GBK 的升级版,支持更广泛的字符范围。"
|
||||
await awrite(filename=pathname, data=content, encoding="gb2312")
|
||||
data = await aread(filename=pathname, encoding="utf-8")
|
||||
assert data == content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -10,15 +10,14 @@
|
|||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.common import awrite
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
async def mock_file(filename, content=""):
|
||||
async with aiofiles.open(str(filename), mode="w") as file:
|
||||
await file.write(content)
|
||||
await awrite(filename=filename, data=content)
|
||||
|
||||
|
||||
async def mock_repo(local_path) -> (GitRepository, Path):
|
||||
|
|
|
|||
25
tests/metagpt/utils/test_repo_to_markdown.py
Normal file
25
tests/metagpt/utils/test_repo_to_markdown.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.repo_to_markdown import repo_to_markdown
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["repo_path", "output"],
|
||||
[(Path(__file__).parent.parent, Path(__file__).parent.parent.parent / f"workspace/unittest/{uuid.uuid4().hex}.md")],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_repo_to_markdown(repo_path: Path, output: Path):
|
||||
markdown = await repo_to_markdown(repo_path=repo_path, output=output)
|
||||
assert output.exists()
|
||||
assert markdown
|
||||
|
||||
output.unlink(missing_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -9,7 +9,6 @@ import uuid
|
|||
from pathlib import Path
|
||||
|
||||
import aioboto3
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
|
|
@ -46,7 +45,7 @@ async def test_s3(mocker):
|
|||
conn = S3(s3)
|
||||
object_name = "unittest.bak"
|
||||
await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name)
|
||||
pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak")
|
||||
pathname = (Path(__file__).parent / "../../../workspace/unittest" / uuid.uuid4().hex).with_suffix(".bak")
|
||||
pathname.unlink(missing_ok=True)
|
||||
await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname))
|
||||
assert pathname.exists()
|
||||
|
|
@ -54,8 +53,7 @@ async def test_s3(mocker):
|
|||
assert url
|
||||
bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name)
|
||||
assert bin_data
|
||||
async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader:
|
||||
data = await reader.read()
|
||||
data = await aread(filename=__file__)
|
||||
res = await conn.cache(data, ".bak", "script")
|
||||
assert "http" in res
|
||||
|
||||
|
|
@ -69,8 +67,6 @@ async def test_s3(mocker):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
await reader.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def _paragraphs(n):
|
|||
@pytest.mark.parametrize(
|
||||
"msgs, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-0613", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
|
||||
(_msgs(), "gpt-4", "System", 2000, 3),
|
||||
|
|
@ -32,22 +32,23 @@ def _paragraphs(n):
|
|||
],
|
||||
)
|
||||
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
|
||||
assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
|
||||
length = len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000
|
||||
assert length == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, prompt_template, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
|
||||
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
|
||||
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
|
||||
(" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1000, 8),
|
||||
(" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1000, 8),
|
||||
],
|
||||
)
|
||||
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
|
||||
ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
|
||||
assert len(ret) == expected
|
||||
chunk = len(list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved)))
|
||||
assert chunk == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
64
tests/metagpt/utils/test_tree.py
Normal file
64
tests/metagpt/utils/test_tree.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.tree import _print_tree, tree
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("root", "rules"),
|
||||
[
|
||||
(str(Path(__file__).parent / "../.."), None),
|
||||
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
|
||||
],
|
||||
)
|
||||
def test_tree(root: str, rules: str):
|
||||
v = tree(root=root, gitignore=rules)
|
||||
assert v
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("root", "rules"),
|
||||
[
|
||||
(str(Path(__file__).parent / "../.."), None),
|
||||
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
|
||||
],
|
||||
)
|
||||
def test_tree_command(root: str, rules: str):
|
||||
v = tree(root=root, gitignore=rules, run_command=True)
|
||||
assert v
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tree", "want"),
|
||||
[
|
||||
({"a": {"b": {}, "c": {}}}, ["a", "+-- b", "+-- c"]),
|
||||
({"a": {"b": {}, "c": {"d": {}}}}, ["a", "+-- b", "+-- c", " +-- d"]),
|
||||
(
|
||||
{"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}},
|
||||
["a", "+-- b", "| +-- e", "| +-- f", "| +-- g", "+-- c", " +-- d"],
|
||||
),
|
||||
(
|
||||
{"h": {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}, "i": {}}},
|
||||
[
|
||||
"h",
|
||||
"+-- a",
|
||||
"| +-- b",
|
||||
"| | +-- e",
|
||||
"| | +-- f",
|
||||
"| | +-- g",
|
||||
"| +-- c",
|
||||
"| +-- d",
|
||||
"+-- i",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test__print_tree(tree: dict, want: List[str]):
|
||||
v = _print_tree(tree)
|
||||
assert v == want
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue