mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-14 15:25:17 +02:00
feat: merge geekan:dev
This commit is contained in:
commit
739452edbb
92 changed files with 5372 additions and 197 deletions
|
|
@ -38,14 +38,14 @@ def rsp_cache():
|
|||
rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache.json" # read repo-provided
|
||||
new_rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache_new.json" # exporting a new copy
|
||||
if os.path.exists(rsp_cache_file_path):
|
||||
with open(rsp_cache_file_path, "r") as f1:
|
||||
with open(rsp_cache_file_path, "r", encoding="utf-8") as f1:
|
||||
rsp_cache_json = json.load(f1)
|
||||
else:
|
||||
rsp_cache_json = {}
|
||||
yield rsp_cache_json
|
||||
with open(rsp_cache_file_path, "w") as f2:
|
||||
with open(rsp_cache_file_path, "w", encoding="utf-8") as f2:
|
||||
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
|
||||
with open(new_rsp_cache_file_path, "w") as f2:
|
||||
with open(new_rsp_cache_file_path, "w", encoding="utf-8") as f2:
|
||||
json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
|
|
@ -64,6 +64,7 @@ def llm_mock(rsp_cache, mocker, request):
|
|||
llm.rsp_cache = rsp_cache
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask_batch", llm.aask_batch)
|
||||
mocker.patch("metagpt.provider.openai_api.OpenAILLM.aask_code", llm.aask_code)
|
||||
yield mocker
|
||||
if hasattr(request.node, "test_outcome") and request.node.test_outcome.passed:
|
||||
if llm.rsp_candidates:
|
||||
|
|
@ -71,7 +72,7 @@ def llm_mock(rsp_cache, mocker, request):
|
|||
cand_key = list(rsp_candidate.keys())[0]
|
||||
cand_value = list(rsp_candidate.values())[0]
|
||||
if cand_key not in llm.rsp_cache:
|
||||
logger.info(f"Added '{cand_key[:100]} ... -> {cand_value[:20]} ...' to response cache")
|
||||
logger.info(f"Added '{cand_key[:100]} ... -> {str(cand_value)[:20]} ...' to response cache")
|
||||
llm.rsp_cache.update(rsp_candidate)
|
||||
RSP_CACHE_NEW.update(rsp_candidate)
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,5 @@
|
|||
{"world_name": "the ville",
|
||||
"maze_width": 140,
|
||||
"maze_height": 100,
|
||||
"sq_tile_size": 32,
|
||||
"special_constraint": ""}
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
32138, the Ville, artist's co-living space, Latoya Williams's room
|
||||
32148, the Ville, artist's co-living space, Latoya Williams's bathroom
|
||||
32158, the Ville, artist's co-living space, Rajiv Patel's room
|
||||
32168, the Ville, artist's co-living space, Rajiv Patel's bathroom
|
||||
32178, the Ville, artist's co-living space, Abigail Chen's room
|
||||
32188, the Ville, artist's co-living space, Abigail Chen's bathroom
|
||||
32198, the Ville, artist's co-living space, Francisco Lopez's room
|
||||
32139, the Ville, artist's co-living space, Francisco Lopez's bathroom
|
||||
32149, the Ville, artist's co-living space, Hailey Johnson's room
|
||||
32159, the Ville, artist's co-living space, Hailey Johnson's bathroom
|
||||
32179, the Ville, artist's co-living space, common room
|
||||
32189, the Ville, artist's co-living space, kitchen
|
||||
32199, the Ville, Arthur Burton's apartment, main room
|
||||
32140, the Ville, Arthur Burton's apartment, bathroom
|
||||
32150, the Ville, Ryan Park's apartment, main room
|
||||
32160, the Ville, Ryan Park's apartment, bathroom
|
||||
32170, the Ville, Isabella Rodriguez's apartment, main room
|
||||
32180, the Ville, Isabella Rodriguez's apartment, bathroom
|
||||
32190, the Ville, Giorgio Rossi's apartment, main room
|
||||
32200, the Ville, Giorgio Rossi's apartment, bathroom
|
||||
32141, the Ville, Carlos Gomez's apartment, main room
|
||||
32151, the Ville, Carlos Gomez's apartment, bathroom
|
||||
32161, the Ville, The Rose and Crown Pub, pub
|
||||
32171, the Ville, Hobbs Cafe, cafe
|
||||
32181, the Ville, Oak Hill College, classroom
|
||||
32191, the Ville, Oak Hill College, library
|
||||
32201, the Ville, Oak Hill College, hallway
|
||||
32142, the Ville, Johnson Park, park
|
||||
32152, the Ville, Harvey Oak Supply Store, supply store
|
||||
32162, the Ville, The Willows Market and Pharmacy, store
|
||||
32193, the Ville, Adam Smith's house, main room
|
||||
32203, the Ville, Adam Smith's house, bathroom
|
||||
32174, the Ville, Yuriko Yamamoto's house, main room
|
||||
32184, the Ville, Yuriko Yamamoto's house, bathroom
|
||||
32194, the Ville, Moore family's house, main room
|
||||
32204, the Ville, Moore family's house, bathroom
|
||||
32172, the Ville, Dorm for Oak Hill College, Klaus Mueller's room
|
||||
32182, the Ville, Dorm for Oak Hill College, Maria Lopez's room
|
||||
32192, the Ville, Dorm for Oak Hill College, Ayesha Khan's room
|
||||
32202, the Ville, Dorm for Oak Hill College, Wolfgang Schulz's room
|
||||
32143, the Ville, Dorm for Oak Hill College, man's bathroom
|
||||
32153, the Ville, Dorm for Oak Hill College, woman's bathroom
|
||||
32163, the Ville, Dorm for Oak Hill College, common room
|
||||
32173, the Ville, Dorm for Oak Hill College, kitchen
|
||||
32183, the Ville, Dorm for Oak Hill College, garden
|
||||
32205, the Ville, Tamara Taylor and Carmen Ortiz's house, Tamara Taylor's room
|
||||
32215, the Ville, Tamara Taylor and Carmen Ortiz's house, Carmen Ortiz's room
|
||||
32225, the Ville, Tamara Taylor and Carmen Ortiz's house, common room
|
||||
32235, the Ville, Tamara Taylor and Carmen Ortiz's house, kitchen
|
||||
32245, the Ville, Tamara Taylor and Carmen Ortiz's house, bathroom
|
||||
32255, the Ville, Tamara Taylor and Carmen Ortiz's house, garden
|
||||
32265, the Ville, Moreno family's house, Tom and Jane Moreno's bedroom
|
||||
32275, the Ville, Moreno family's house, empty bedroom
|
||||
32206, the Ville, Moreno family's house, common room
|
||||
32216, the Ville, Moreno family's house, kitchen
|
||||
32226, the Ville, Moreno family's house, bathroom
|
||||
32236, the Ville, Moreno family's house, garden
|
||||
32246, the Ville, Lin family's house, Mei and John Lin's bedroom
|
||||
32256, the Ville, Lin family's house, Eddy Lin's bedroom
|
||||
32266, the Ville, Lin family's house, common room
|
||||
32276, the Ville, Lin family's house, kitchen
|
||||
32207, the Ville, Lin family's house, bathroom
|
||||
32217, the Ville, Lin family's house, garden
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
32227, the Ville, <all>, bed
|
||||
32237, the Ville, <all>, desk
|
||||
32247, the Ville, <all>, closet
|
||||
32257, the Ville, <all>, shelf
|
||||
32267, the Ville, <all>, easel
|
||||
32277, the Ville, <all>, bathroom sink
|
||||
32208, the Ville, <all>, shower
|
||||
32218, the Ville, <all>, toilet
|
||||
32228, the Ville, <all>, kitchen sink
|
||||
32238, the Ville, <all>, refrigerator
|
||||
32248, the Ville, <all>, toaster
|
||||
32258, the Ville, <all>, cooking area
|
||||
32268, the Ville, <all>, common room table
|
||||
32278, the Ville, <all>, common room sofa
|
||||
32209, the Ville, <all>, guitar
|
||||
32219, the Ville, <all>, microphone
|
||||
32229, the Ville, <all>, bar customer seating
|
||||
32239, the Ville, <all>, behind the bar counter
|
||||
32249, the Ville, <all>, behind the cafe counter
|
||||
32259, the Ville, <all>, cafe customer seating
|
||||
32269, the Ville, <all>, piano
|
||||
32279, the Ville, <all>, blackboard
|
||||
32210, the Ville, <all>, game console
|
||||
32220, the Ville, <all>, computer desk
|
||||
32230, the Ville, <all>, computer
|
||||
32240, the Ville, <all>, library sofa
|
||||
32250, the Ville, <all>, bookshelf
|
||||
32260, the Ville, <all>, library table
|
||||
32270, the Ville, <all>, classroom student seating
|
||||
32280, the Ville, <all>, classroom podium
|
||||
32211, the Ville, <all>, behind the pharmacy counter
|
||||
32221, the Ville, <all>, behind the grocery counter
|
||||
32231, the Ville, <all>, pharmacy store shelf
|
||||
32241, the Ville, <all>, grocery store shelf
|
||||
32251, the Ville, <all>, pharmacy store counter
|
||||
32261, the Ville, <all>, grocery store counter
|
||||
32271, the Ville, <all>, supply store product shelf
|
||||
32281, the Ville, <all>, behind the supply store counter
|
||||
32212, the Ville, <all>, supply store counter
|
||||
32222, the Ville, <all>, dorm garden
|
||||
32232, the Ville, <all>, house garden
|
||||
32242, the Ville, <all>, garden chair
|
||||
32252, the Ville, <all>, park garden
|
||||
32262, the Ville, <all>, harp
|
||||
32272, the Ville, <all>, lifting weight
|
||||
32282, the Ville, <all>, pool table
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
32135, the Ville, artist's co-living space
|
||||
32145, the Ville, Arthur Burton's apartment
|
||||
32155, the Ville, Ryan Park's apartment
|
||||
32165, the Ville, Isabella Rodriguez's apartment
|
||||
32175, the Ville, Giorgio Rossi's apartment
|
||||
32185, the Ville, Carlos Gomez's apartment
|
||||
32195, the Ville, The Rose and Crown Pub
|
||||
32136, the Ville, Hobbs Cafe
|
||||
32146, the Ville, Oak Hill College
|
||||
32156, the Ville, Johnson Park
|
||||
32166, the Ville, Harvey Oak Supply Store
|
||||
32176, the Ville, The Willows Market and Pharmacy
|
||||
32186, the Ville, Adam Smith's house
|
||||
32196, the Ville, Yuriko Yamamoto's house
|
||||
32137, the Ville, Moore family's house
|
||||
32147, the Ville, Tamara Taylor and Carmen Ortiz's house
|
||||
32157, the Ville, Moreno family's house
|
||||
32167, the Ville, Lin family's house
|
||||
32177, the Ville, Dorm for Oak Hill College
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
32285, the Ville, artist's co-living space, Latoya Williams's room, sp-A
|
||||
32295, the Ville, artist's co-living space, Rajiv Patel's room, sp-A
|
||||
32305, the Ville, artist's co-living space, Rajiv Patel's room, sp-B
|
||||
32315, the Ville, artist's co-living space, Abigail Chen's room, sp-A
|
||||
32286, the Ville, artist's co-living space, Francisco Lopez's room, sp-A
|
||||
32296, the Ville, artist's co-living space, Hailey Johnson's room, sp-A
|
||||
32306, the Ville, Arthur Burton's apartment, main room, sp-A
|
||||
32316, the Ville, Arthur Burton's apartment, main room, sp-B
|
||||
32287, the Ville, Ryan Park's apartment, main room, sp-A
|
||||
32297, the Ville, Ryan Park's apartment, main room, sp-B
|
||||
32307, the Ville, Isabella Rodriguez's apartment, main room, sp-A
|
||||
32317, the Ville, Isabella Rodriguez's apartment, main room, sp-B
|
||||
32288, the Ville, Giorgio Rossi's apartment, main room, sp-A
|
||||
32298, the Ville, Giorgio Rossi's apartment, main room, sp-B
|
||||
32308, the Ville, Carlos Gomez's apartment, main room, sp-A
|
||||
32318, the Ville, Carlos Gomez's apartment, main room, sp-B
|
||||
32289, the Ville, Adam Smith's house, main room, sp-A
|
||||
32299, the Ville, Adam Smith's house, main room, sp-B
|
||||
32309, the Ville, Yuriko Yamamoto's house, main room, sp-A
|
||||
32319, the Ville, Yuriko Yamamoto's house, main room, sp-B
|
||||
32290, the Ville, Moore family's house, main room, sp-A
|
||||
32300, the Ville, Moore family's house, main room, sp-B
|
||||
32310, the Ville, Tamara Taylor and Carmen Ortiz's house, Tamara Taylor's room, sp-A
|
||||
32320, the Ville, Tamara Taylor and Carmen Ortiz's house, Tamara Taylor's room, sp-B
|
||||
32291, the Ville, Tamara Taylor and Carmen Ortiz's house, Carmen Ortiz's room, sp-A
|
||||
32301, the Ville, Tamara Taylor and Carmen Ortiz's house, Carmen Ortiz's room, sp-B
|
||||
32311, the Ville, Moreno family's house, Tom and Jane Moreno's bedroom, sp-A
|
||||
32321, the Ville, Moreno family's house, Tom and Jane Moreno's bedroom, sp-B
|
||||
32292, the Ville, Moreno family's house, empty bedroom, sp-A
|
||||
32302, the Ville, Moreno family's house, empty bedroom, sp-B
|
||||
32312, the Ville, Lin family's house, Mei and John Lin's bedroom, sp-A
|
||||
32322, the Ville, Lin family's house, Mei and John Lin's bedroom, sp-B
|
||||
32293, the Ville, Lin family's house, Eddy Lin's bedroom, sp-A
|
||||
32303, the Ville, Lin family's house, Eddy Lin's bedroom, sp-B
|
||||
32313, the Ville, Dorm for Oak Hill College, Klaus Mueller's room, sp-A
|
||||
32323, the Ville, Dorm for Oak Hill College, Klaus Mueller's room, sp-B
|
||||
32294, the Ville, Dorm for Oak Hill College, Maria Lopez's room, sp-A
|
||||
32304, the Ville, Dorm for Oak Hill College, Ayesha Khan's room, sp-A
|
||||
32314, the Ville, Dorm for Oak Hill College, Ayesha Khan's room, sp-B
|
||||
32324, the Ville, Dorm for Oak Hill College, Wolfgang Schulz's room, sp-A
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
32134, the Ville
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
12
tests/metagpt/actions/ci/test_ask_review.py
Normal file
12
tests/metagpt/actions/ci/test_ask_review.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.ask_review import AskReview
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_review(mocker):
|
||||
mock_review_input = "confirm"
|
||||
mocker.patch("builtins.input", return_value=mock_review_input)
|
||||
rsp, confirmed = await AskReview().run()
|
||||
assert rsp == mock_review_input
|
||||
assert confirmed
|
||||
51
tests/metagpt/actions/ci/test_debug_code.py
Normal file
51
tests/metagpt/actions/ci/test_debug_code.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/11/2024 8:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.ci.debug_code import DebugCode
|
||||
from metagpt.schema import Message
|
||||
|
||||
ErrorStr = """Tested passed:
|
||||
|
||||
Tests failed:
|
||||
assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5]
|
||||
"""
|
||||
|
||||
CODE = """
|
||||
def sort_array(arr):
|
||||
# Helper function to count the number of ones in the binary representation
|
||||
def count_ones(n):
|
||||
return bin(n).count('1')
|
||||
|
||||
# Sort the array using a custom key function
|
||||
# The key function returns a tuple (number of ones, value) for each element
|
||||
# This ensures that if two elements have the same number of ones, they are sorted by their value
|
||||
sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x))
|
||||
|
||||
return sorted_arr
|
||||
```
|
||||
"""
|
||||
|
||||
DebugContext = '''Solve the problem in Python:
|
||||
def sort_array(arr):
|
||||
"""
|
||||
In this Kata, you have to sort an array of non-negative integers according to
|
||||
number of ones in their binary representation in ascending order.
|
||||
For similar number of ones, sort based on decimal value.
|
||||
|
||||
It must be implemented like this:
|
||||
>>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
|
||||
>>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
|
||||
>>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
|
||||
"""
|
||||
'''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_code():
|
||||
debug_context = Message(content=DebugContext)
|
||||
new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
|
||||
assert "def sort_array(arr)" in new_code["code"]
|
||||
116
tests/metagpt/actions/ci/test_execute_nb_code.py
Normal file
116
tests/metagpt/actions/ci/test_execute_nb_code.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode, truncate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_running():
|
||||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run("print('hello world!')")
|
||||
assert is_success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_code_running():
|
||||
executor = ExecuteNbCode()
|
||||
_ = await executor.run("x=1\ny=2")
|
||||
_ = await executor.run("z=x+y")
|
||||
output, is_success = await executor.run("assert z==3")
|
||||
assert is_success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_error():
|
||||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run("z=1/0")
|
||||
assert not is_success
|
||||
|
||||
|
||||
PLOT_CODE = """
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 生成随机数据
|
||||
random_data = np.random.randn(1000) # 生成1000个符合标准正态分布的随机数
|
||||
|
||||
# 绘制直方图
|
||||
plt.hist(random_data, bins=30, density=True, alpha=0.7, color='blue', edgecolor='black')
|
||||
|
||||
# 添加标题和标签
|
||||
plt.title('Histogram of Random Data')
|
||||
plt.xlabel('Value')
|
||||
plt.ylabel('Frequency')
|
||||
|
||||
# 显示图形
|
||||
plt.show()
|
||||
plt.close()
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plotting_code():
|
||||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run(PLOT_CODE)
|
||||
assert is_success
|
||||
|
||||
|
||||
def test_truncate():
|
||||
# 代码执行成功
|
||||
output, is_success = truncate("hello world", 5, True)
|
||||
assert "Truncated to show only first 5 characters\nhello" in output
|
||||
assert is_success
|
||||
# 代码执行失败
|
||||
output, is_success = truncate("hello world", 5, False)
|
||||
assert "Truncated to show only last 5 characters\nworld" in output
|
||||
assert not is_success
|
||||
# 异步
|
||||
output, is_success = truncate("<coroutine object", 5, True)
|
||||
assert not is_success
|
||||
assert "await" in output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_timeout():
|
||||
executor = ExecuteNbCode(timeout=1)
|
||||
code = "import time; time.sleep(2)"
|
||||
message, success = await executor.run(code)
|
||||
assert not success
|
||||
assert message.startswith("Cell execution timed out")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_code_text():
|
||||
executor = ExecuteNbCode()
|
||||
message, success = await executor.run(code='print("This is a code!")', language="python")
|
||||
assert success
|
||||
assert message == "This is a code!\n"
|
||||
message, success = await executor.run(code="# This is a code!", language="markdown")
|
||||
assert success
|
||||
assert message == "# This is a code!"
|
||||
mix_text = "# Title!\n ```python\n print('This is a code!')```"
|
||||
message, success = await executor.run(code=mix_text, language="markdown")
|
||||
assert success
|
||||
assert message == mix_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_terminate():
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await executor.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await executor.terminate()
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
assert executor.nb_client.km is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset():
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await executor.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await executor.reset()
|
||||
assert executor.nb_client.km is None
|
||||
46
tests/metagpt/actions/ci/test_ml_action.py
Normal file
46
tests/metagpt/actions/ci/test_ml_action.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.ml_action import WriteCodeWithToolsML
|
||||
from metagpt.schema import Plan, Task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code_ml = WriteCodeWithToolsML()
|
||||
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="随机生成一个pandas DataFrame数据集",
|
||||
task_type="other",
|
||||
dependent_task_ids=[],
|
||||
code="""
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({
|
||||
'a': [1, 2, 3, 4, 5],
|
||||
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
|
||||
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
|
||||
'd': [1, 2, 3, 4, 5]
|
||||
})
|
||||
""",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="对数据集进行数据清洗",
|
||||
task_type="data_preprocess",
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
plan = Plan(
|
||||
goal="构造数据集并进行数据清洗",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
column_info = ""
|
||||
|
||||
_, code_with_ml = await write_code_ml.run([], plan, column_info)
|
||||
code_with_ml = code_with_ml["code"]
|
||||
assert len(code_with_ml) > 0
|
||||
print(code_with_ml)
|
||||
324
tests/metagpt/actions/ci/test_write_analysis_code.py
Normal file
324
tests/metagpt/actions/ci/test_write_analysis_code.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.ci.write_analysis_code import (
|
||||
WriteCodeWithoutTools,
|
||||
WriteCodeWithTools,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.strategy.planner import STRUCTURAL_CONTEXT
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_by_list_plan():
|
||||
write_code = WriteCodeWithoutTools()
|
||||
execute_code = ExecuteNbCode()
|
||||
messages = []
|
||||
plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "回顾已完成的任务", "求均值", "总结"]
|
||||
for task in plan:
|
||||
print(f"\n任务: {task}\n\n")
|
||||
messages.append(Message(task, role="assistant"))
|
||||
code = await write_code.run(messages)
|
||||
if task.startswith(("回顾", "总结")):
|
||||
assert code["language"] == "markdown"
|
||||
else:
|
||||
assert code["language"] == "python"
|
||||
messages.append(Message(code["code"], role="assistant"))
|
||||
assert len(code) > 0
|
||||
output, _ = await execute_code.run(**code)
|
||||
print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n")
|
||||
messages.append(output)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_recommendation():
|
||||
task = "clean and preprocess the data"
|
||||
available_tools = {
|
||||
"FillMissingValue": "Filling missing values",
|
||||
"SplitBins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
|
||||
}
|
||||
write_code = WriteCodeWithTools()
|
||||
tools = await write_code._recommend_tool(task, available_tools)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert "FillMissingValue" in tools
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code = WriteCodeWithTools()
|
||||
|
||||
requirement = "构造数据集并进行数据清洗"
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="随机生成一个pandas DataFrame数据集",
|
||||
task_type="other",
|
||||
dependent_task_ids=[],
|
||||
code="""
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({
|
||||
'a': [1, 2, 3, 4, 5],
|
||||
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
|
||||
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
|
||||
'd': [1, 2, 3, 4, 5]
|
||||
})
|
||||
""",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="对数据集进行数据清洗",
|
||||
task_type="data_preprocess",
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
plan = Plan(
|
||||
goal="构造数据集并进行数据清洗",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
|
||||
context = STRUCTURAL_CONTEXT.format(
|
||||
user_requirement=requirement,
|
||||
context=plan.context,
|
||||
tasks=list(task_map.values()),
|
||||
current_task=plan.current_task.model_dump_json(),
|
||||
)
|
||||
context_msg = [Message(content=context, role="user")]
|
||||
|
||||
code = await write_code.run(context_msg, plan)
|
||||
code = code["code"]
|
||||
assert len(code) > 0
|
||||
print(code)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_to_correct_error():
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
read a dataset test.csv and print its head
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "import pandas and load the dataset from 'test.csv'.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Print the head of the dataset to display the first few rows.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "1", "dependent_task_ids": [], "instruction": "import pandas and load the dataset from 'test.csv'.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv
|
||||
error = """
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 2, in <module>
|
||||
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 478, in read_excel
|
||||
io = ExcelFile(io, storage_options=storage_options, engine=engine)
|
||||
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 1500, in __init__
|
||||
raise ValueError(
|
||||
ValueError: Excel file format cannot be determined, you must specify an engine manually.
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
Message(content=wrong_code, role="assistant"),
|
||||
Message(content=error, role="user"),
|
||||
]
|
||||
new_code = await WriteCodeWithoutTools().run(context=context)
|
||||
new_code = new_code["code"]
|
||||
print(new_code)
|
||||
assert "read_csv" in new_code # should correct read_excel to read_csv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_simple():
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
read a dataset test.csv and print its head
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "import pandas and load the dataset from 'test.csv'.",
|
||||
"task_type": "",
|
||||
"code": "import pandas as pd\ndata = pd.read_csv('test.csv')",
|
||||
"result": "",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Print the head of the dataset to display the first few rows.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Print the head of the dataset to display the first few rows.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
code = await WriteCodeWithoutTools().run(context=context)
|
||||
code = code["code"]
|
||||
print(code)
|
||||
assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long():
|
||||
"""test code reuse for long context"""
|
||||
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
Run data analysis on sklearn Iris dataset, include a plot
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "Load the Iris dataset from sklearn.",
|
||||
"task_type": "",
|
||||
"code": "from sklearn.datasets import load_iris\niris_data = load_iris()\niris_data['data'][0:5], iris_data['target'][0:5]",
|
||||
"result": "(array([[5.1, 3.5, 1.4, 0.2],\n [4.9, 3. , 1.4, 0.2],\n [4.7, 3.2, 1.3, 0.2],\n [4.6, 3.1, 1.5, 0.2],\n [5. , 3.6, 1.4, 0.2]]),\n array([0, 0, 0, 0, 0]))",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Perform exploratory data analysis on the Iris dataset.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "3",
|
||||
"dependent_task_ids": [
|
||||
"2"
|
||||
],
|
||||
"instruction": "Create a plot visualizing the Iris dataset features.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Perform exploratory data analysis on the Iris dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
trials_num = 5
|
||||
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
|
||||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_iris" not in result["code"] and "iris_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
assert success_rate >= 0.8
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long_for_wine():
|
||||
"""test code reuse for long context"""
|
||||
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "Load the sklearn Wine recognition dataset and perform exploratory data analysis."
|
||||
"task_type": "",
|
||||
"code": "from sklearn.datasets import load_wine\n# Load the Wine recognition dataset\nwine_data = load_wine()\n# Perform exploratory data analysis\nwine_data.keys()",
|
||||
"result": "Truncated to show only the last 1000 characters\ndict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names'])",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": ["1"],
|
||||
"instruction": "Create a plot to visualize some aspect of the wine dataset."
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "3",
|
||||
"dependent_task_ids": ["1"],
|
||||
"instruction": "Split the dataset into training and validation sets with a 20% validation size.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "4",
|
||||
"dependent_task_ids": ["3"],
|
||||
"instruction": "Train a model on the training set to predict wine class.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "5",
|
||||
"dependent_task_ids": ["4"],
|
||||
"instruction": "Evaluate the model on the validation set and report the accuracy.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Create a plot to visualize some aspect of the Wine dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
trials_num = 5
|
||||
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
|
||||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_wine" not in result["code"] and "wine_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
assert success_rate >= 0.8
|
||||
34
tests/metagpt/actions/ci/test_write_plan.py
Normal file
34
tests/metagpt/actions/ci/test_write_plan.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.write_plan import (
|
||||
Plan,
|
||||
Task,
|
||||
WritePlan,
|
||||
precheck_update_plan_from_rsp,
|
||||
)
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_precheck_update_plan_from_rsp():
|
||||
plan = Plan(goal="")
|
||||
plan.add_tasks([Task(task_id="1")])
|
||||
rsp = '[{"task_id": "2"}]'
|
||||
success, _ = precheck_update_plan_from_rsp(rsp, plan)
|
||||
assert success
|
||||
assert len(plan.tasks) == 1 and plan.tasks[0].task_id == "1" # precheck should not change the original one
|
||||
|
||||
invalid_rsp = "wrong"
|
||||
success, _ = precheck_update_plan_from_rsp(invalid_rsp, plan)
|
||||
assert not success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("use_tools", [(False), (True)])
|
||||
async def test_write_plan(use_tools):
|
||||
rsp = await WritePlan().run(
|
||||
context=[Message("run analysis on sklearn iris dataset", role="user")], use_tools=use_tools
|
||||
)
|
||||
|
||||
assert "task_id" in rsp
|
||||
assert "instruction" in rsp
|
||||
assert "json" not in rsp # the output should be the content inside ```json ```
|
||||
|
|
@ -244,13 +244,19 @@ def test_create_model_class_with_mapping():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_node_with_image():
|
||||
async def test_action_node_with_image(mocker):
|
||||
# add a mock to update model in unittest, due to the gloabl MockLLM
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
|
||||
kwargs = {"messages": messages, "temperature": 0.3, "model": "gpt-4-vision-preview"}
|
||||
return kwargs
|
||||
|
||||
invoice = ActionNode(
|
||||
key="invoice", expected_type=bool, instruction="if it's a invoice file, return True else False", example="False"
|
||||
)
|
||||
|
||||
invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png")
|
||||
img_base64 = encode_image(invoice_path)
|
||||
mocker.patch("metagpt.provider.openai_api.OpenAILLM._cons_kwargs", _cons_kwargs)
|
||||
node = await invoice.fill(context="", llm=LLM(), images=[img_base64])
|
||||
assert node.instruct_content.invoice
|
||||
|
||||
|
|
|
|||
|
|
@ -1,49 +1,25 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from PIL import Image
|
||||
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider import OpenAILLM
|
||||
from metagpt.schema import UserMessage
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_proxy,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
llm = LLM()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
|
||||
logger.info(rsp)
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = LLM()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_message():
|
||||
llm = LLM()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_speech():
|
||||
llm = LLM()
|
||||
|
|
@ -63,16 +39,41 @@ async def test_speech_to_text():
|
|||
assert "你好" == resp.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gen_image():
|
||||
llm = LLM()
|
||||
model = "dall-e-3"
|
||||
prompt = 'a logo with word "MetaGPT"'
|
||||
images: list[Image] = await llm.gen_image(model=model, prompt=prompt)
|
||||
assert images[0].size == (1024, 1024)
|
||||
@pytest.fixture
|
||||
def tool_calls_rsp():
|
||||
function_rsps = [
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
]
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCall(type="function", id=f"call_{i}", function=f) for i, f in enumerate(function_rsps)
|
||||
]
|
||||
messages = [ChatCompletionMessage(content=None, role="assistant", tool_calls=[t]) for t in tool_calls]
|
||||
# 添加一个纯文本响应
|
||||
messages.append(
|
||||
ChatCompletionMessage(content="Completed a python code for hello world!", role="assistant", tool_calls=None)
|
||||
)
|
||||
# 添加 openai tool calls respond bug, code 出现在ChatCompletionMessage.content中
|
||||
messages.extend(
|
||||
[
|
||||
ChatCompletionMessage(content="```python\nprint('hello world')```", role="assistant", tool_calls=None),
|
||||
]
|
||||
)
|
||||
choices = [
|
||||
Choice(finish_reason="tool_calls", logprobs=None, index=i, message=msg) for i, msg in enumerate(messages)
|
||||
]
|
||||
return [
|
||||
ChatCompletion(id=str(i), choices=[c], created=i, model="gpt-4", object="chat.completion")
|
||||
for i, c in enumerate(choices)
|
||||
]
|
||||
|
||||
images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json")
|
||||
assert images[0].size == (1024, 1024)
|
||||
|
||||
@pytest.fixture
|
||||
def json_decode_error():
|
||||
function_rsp = Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"}', name="execute")
|
||||
tool_calls = [ChatCompletionMessageToolCall(type="function", id=f"call_{0}", function=function_rsp)]
|
||||
message = ChatCompletionMessage(content=None, role="assistant", tool_calls=tool_calls)
|
||||
choices = [Choice(finish_reason="tool_calls", logprobs=None, index=0, message=message)]
|
||||
return ChatCompletion(id="0", choices=choices, created=0, model="gpt-4", object="chat.completion")
|
||||
|
||||
|
||||
class TestOpenAI:
|
||||
|
|
@ -87,3 +88,36 @@ class TestOpenAI:
|
|||
instance = OpenAILLM(mock_llm_config_proxy)
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
||||
def test_get_choice_function_arguments_for_aask_code(self, tool_calls_rsp):
|
||||
instance = OpenAILLM(mock_llm_config_proxy)
|
||||
for i, rsp in enumerate(tool_calls_rsp):
|
||||
code = instance.get_choice_function_arguments(rsp)
|
||||
logger.info(f"\ntest get function call arguments {i}: {code}")
|
||||
assert "code" in code
|
||||
assert "language" in code
|
||||
assert "hello world" in code["code"]
|
||||
logger.info(f'code is : {code["code"]}')
|
||||
|
||||
if "Completed a python code for hello world!" == code["code"]:
|
||||
code["language"] == "markdown"
|
||||
else:
|
||||
code["language"] == "python"
|
||||
|
||||
def test_aask_code_json_decode_error(self, json_decode_error):
|
||||
instance = OpenAILLM(mock_llm_config)
|
||||
with pytest.raises(json.decoder.JSONDecodeError) as e:
|
||||
instance.get_choice_function_arguments(json_decode_error)
|
||||
assert "JSONDecodeError" in str(e)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gen_image():
|
||||
llm = LLM()
|
||||
model = "dall-e-3"
|
||||
prompt = 'a logo with word "MetaGPT"'
|
||||
images: list[Image] = await llm.gen_image(model=model, prompt=prompt)
|
||||
assert images[0].size == (1024, 1024)
|
||||
|
||||
images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json")
|
||||
assert images[0].size == (1024, 1024)
|
||||
|
|
|
|||
19
tests/metagpt/roles/ci/test_code_interpreter.py
Normal file
19
tests/metagpt/roles/ci/test_code_interpreter.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ci.code_interpreter import CodeInterpreter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("auto_run", [(True), (False)])
|
||||
async def test_code_interpreter(mocker, auto_run):
|
||||
mocker.patch("metagpt.actions.ci.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
|
||||
mocker.patch("builtins.input", return_value="confirm")
|
||||
|
||||
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
tools = []
|
||||
|
||||
ci = CodeInterpreter(auto_run=auto_run, use_tools=True, tools=tools)
|
||||
rsp = await ci.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
90
tests/metagpt/roles/ci/test_ml_engineer.py
Normal file
90
tests/metagpt/roles/ci/test_ml_engineer.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ci.ml_engineer import MLEngineer
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from tests.metagpt.actions.ci.test_debug_code import CODE, DebugContext, ErrorStr
|
||||
|
||||
|
||||
def test_mle_init():
|
||||
ci = MLEngineer(goal="test", auto_run=True, use_tools=True, tools=["tool1", "tool2"])
|
||||
assert ci.tools == []
|
||||
|
||||
|
||||
MockPlan = Plan(
|
||||
goal="This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: 'tests/data/ml_datasets/titanic/split_train.csv', eval data path: 'tests/data/ml_datasets/titanic/split_eval.csv'.",
|
||||
context="",
|
||||
tasks=[
|
||||
Task(
|
||||
task_id="1",
|
||||
dependent_task_ids=[],
|
||||
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
|
||||
task_type="eda",
|
||||
code="",
|
||||
result="",
|
||||
is_success=False,
|
||||
is_finished=False,
|
||||
)
|
||||
],
|
||||
task_map={
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
dependent_task_ids=[],
|
||||
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
|
||||
task_type="eda",
|
||||
code="",
|
||||
result="",
|
||||
is_success=False,
|
||||
is_finished=False,
|
||||
)
|
||||
},
|
||||
current_task_id="1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_write_code(mocker):
|
||||
data_path = "tests/data/ml_datasets/titanic"
|
||||
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.planner.plan = MockPlan
|
||||
|
||||
code, _ = await mle._write_code()
|
||||
assert data_path in code["code"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_update_data_columns(mocker):
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.planner.plan = MockPlan
|
||||
|
||||
# manually update task type to test update
|
||||
mle.planner.plan.current_task.task_type = ToolType.DATA_PREPROCESS.value
|
||||
|
||||
result = await mle._update_data_columns()
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_debug_code(mocker):
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.working_memory.add(Message(content=ErrorStr, cause_by=ExecuteNbCode))
|
||||
mle.latest_code = CODE
|
||||
mle.debug_context = DebugContext
|
||||
code, _ = await mle._write_code()
|
||||
assert len(code) > 0
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_ml_engineer():
|
||||
data_path = "tests/data/ml_datasets/titanic"
|
||||
requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
|
||||
tools = ["FillMissingValue", "CatCross", "dummy_tool"]
|
||||
|
||||
mle = MLEngineer(auto_run=True, use_tools=True, tools=tools)
|
||||
rsp = await mle.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
@ -142,6 +142,9 @@ def check_or_create_base_tag(project_path):
|
|||
# Initialize a Git repository
|
||||
subprocess.run(["git", "init"], check=True)
|
||||
|
||||
# Check if the .gitignore exists. If it doesn't exist, create .gitignore and add the comment
|
||||
subprocess.run(f"echo # Ignore these files or directories > {'.gitignore'}", shell=True)
|
||||
|
||||
# Check if the 'base' tag exists
|
||||
check_base_tag_cmd = ["git", "show-ref", "--verify", "--quiet", "refs/tags/base"]
|
||||
if subprocess.run(check_base_tag_cmd).returncode == 0:
|
||||
|
|
|
|||
|
|
@ -22,7 +22,9 @@ from metagpt.schema import (
|
|||
Document,
|
||||
Message,
|
||||
MessageQueue,
|
||||
Plan,
|
||||
SystemMessage,
|
||||
Task,
|
||||
UMLClassAttribute,
|
||||
UMLClassMethod,
|
||||
UMLClassView,
|
||||
|
|
@ -180,5 +182,173 @@ def test_class_view():
|
|||
)
|
||||
|
||||
|
||||
class TestPlan:
|
||||
def test_add_tasks_ordering(self):
|
||||
plan = Plan(goal="")
|
||||
|
||||
tasks = [
|
||||
Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
Task(task_id="3", dependent_task_ids=["2"], instruction="Second"),
|
||||
] # 2 -> 3 -> 1
|
||||
plan.add_tasks(tasks)
|
||||
|
||||
assert [task.task_id for task in plan.tasks] == ["2", "3", "1"]
|
||||
|
||||
def test_add_tasks_to_existing_no_common_prefix(self):
|
||||
plan = Plan(goal="")
|
||||
|
||||
tasks = [
|
||||
Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
Task(task_id="3", dependent_task_ids=["2"], instruction="Second", is_finished=True),
|
||||
] # 2 -> 3 -> 1
|
||||
plan.add_tasks(tasks)
|
||||
|
||||
new_tasks = [Task(task_id="3", instruction="")]
|
||||
plan.add_tasks(new_tasks)
|
||||
|
||||
assert [task.task_id for task in plan.tasks] == ["3"]
|
||||
assert not plan.tasks[0].is_finished # must be the new unfinished task
|
||||
|
||||
def test_add_tasks_to_existing_with_common_prefix(self):
|
||||
plan = Plan(goal="")
|
||||
|
||||
tasks = [
|
||||
Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
Task(task_id="3", dependent_task_ids=["2"], instruction="Second"),
|
||||
] # 2 -> 3 -> 1
|
||||
plan.add_tasks(tasks)
|
||||
plan.finish_current_task() # finish 2
|
||||
plan.finish_current_task() # finish 3
|
||||
|
||||
new_tasks = [
|
||||
Task(task_id="4", dependent_task_ids=["3"], instruction="Third"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
Task(task_id="3", dependent_task_ids=["2"], instruction="Second"),
|
||||
] # 2 -> 3 -> 4, so the common prefix is 2 -> 3, and these two should be obtained from the existing tasks
|
||||
plan.add_tasks(new_tasks)
|
||||
|
||||
assert [task.task_id for task in plan.tasks] == ["2", "3", "4"]
|
||||
assert (
|
||||
plan.tasks[0].is_finished and plan.tasks[1].is_finished
|
||||
) # "2" and "3" should be the original finished one
|
||||
assert plan.current_task_id == "4"
|
||||
|
||||
def test_current_task(self):
|
||||
plan = Plan(goal="")
|
||||
tasks = [
|
||||
Task(task_id="1", dependent_task_ids=["2"], instruction="Second"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
]
|
||||
plan.add_tasks(tasks)
|
||||
assert plan.current_task.task_id == "2"
|
||||
|
||||
def test_finish_task(self):
|
||||
plan = Plan(goal="")
|
||||
tasks = [
|
||||
Task(task_id="1", instruction="First"),
|
||||
Task(task_id="2", dependent_task_ids=["1"], instruction="Second"),
|
||||
]
|
||||
plan.add_tasks(tasks)
|
||||
plan.finish_current_task()
|
||||
assert plan.current_task.task_id == "2"
|
||||
|
||||
def test_finished_tasks(self):
|
||||
plan = Plan(goal="")
|
||||
tasks = [
|
||||
Task(task_id="1", instruction="First"),
|
||||
Task(task_id="2", dependent_task_ids=["1"], instruction="Second"),
|
||||
]
|
||||
plan.add_tasks(tasks)
|
||||
plan.finish_current_task()
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
assert len(finished_tasks) == 1
|
||||
assert finished_tasks[0].task_id == "1"
|
||||
|
||||
def test_reset_task_existing(self):
|
||||
plan = Plan(goal="")
|
||||
task = Task(task_id="1", instruction="Do something", code="print('Hello')", result="Hello", finished=True)
|
||||
plan.add_tasks([task])
|
||||
plan.reset_task("1")
|
||||
reset_task = plan.task_map["1"]
|
||||
assert reset_task.code == ""
|
||||
assert reset_task.result == ""
|
||||
assert not reset_task.is_finished
|
||||
|
||||
def test_reset_task_non_existing(self):
|
||||
plan = Plan(goal="")
|
||||
task = Task(task_id="1", instruction="Do something", code="print('Hello')", result="Hello", finished=True)
|
||||
plan.add_tasks([task])
|
||||
plan.reset_task("2") # Task with ID 2 does not exist
|
||||
assert "1" in plan.task_map
|
||||
assert "2" not in plan.task_map
|
||||
|
||||
def test_replace_task_with_dependents(self):
|
||||
plan = Plan(goal="")
|
||||
tasks = [
|
||||
Task(task_id="1", instruction="First Task", finished=True),
|
||||
Task(task_id="2", instruction="Second Task", dependent_task_ids=["1"], finished=True),
|
||||
]
|
||||
plan.add_tasks(tasks)
|
||||
new_task = Task(task_id="1", instruction="Updated First Task")
|
||||
plan.replace_task(new_task)
|
||||
assert plan.task_map["1"].instruction == "Updated First Task"
|
||||
assert not plan.task_map["2"].is_finished # Dependent task should be reset
|
||||
assert plan.task_map["2"].code == ""
|
||||
assert plan.task_map["2"].result == ""
|
||||
|
||||
def test_replace_task_non_existing(self):
|
||||
plan = Plan(goal="")
|
||||
task = Task(task_id="1", instruction="First Task")
|
||||
plan.add_tasks([task])
|
||||
new_task = Task(task_id="2", instruction="New Task")
|
||||
with pytest.raises(AssertionError):
|
||||
plan.replace_task(new_task) # Task with ID 2 does not exist in plan
|
||||
assert "1" in plan.task_map
|
||||
assert "2" not in plan.task_map
|
||||
|
||||
def test_append_task_with_valid_dependencies(self):
|
||||
plan = Plan(goal="Test")
|
||||
existing_task = [Task(task_id="1")]
|
||||
plan.add_tasks(existing_task)
|
||||
new_task = Task(task_id="2", dependent_task_ids=["1"])
|
||||
plan.append_task(new_task)
|
||||
assert plan.tasks[-1].task_id == "2"
|
||||
assert plan.task_map["2"] == new_task
|
||||
|
||||
def test_append_task_with_invalid_dependencies(self):
|
||||
new_task = Task(task_id="2", dependent_task_ids=["3"])
|
||||
plan = Plan(goal="Test")
|
||||
with pytest.raises(AssertionError):
|
||||
plan.append_task(new_task)
|
||||
|
||||
def test_append_task_without_dependencies(self):
|
||||
plan = Plan(goal="Test")
|
||||
existing_task = [Task(task_id="1")]
|
||||
plan.add_tasks(existing_task)
|
||||
|
||||
new_task = Task(task_id="2")
|
||||
plan.append_task(new_task)
|
||||
|
||||
assert len(plan.tasks) == 2
|
||||
assert plan.current_task_id == "1"
|
||||
|
||||
def test_append_task_updates_current_task(self):
|
||||
finished_task = Task(task_id="1", is_finished=True)
|
||||
new_task = Task(task_id="2")
|
||||
plan = Plan(goal="Test", tasks=[finished_task])
|
||||
plan.append_task(new_task)
|
||||
assert plan.current_task_id == "2"
|
||||
|
||||
def test_update_current_task(self):
|
||||
task1 = Task(task_id="1", is_finished=True)
|
||||
task2 = Task(task_id="2")
|
||||
plan = Plan(goal="Test", tasks=[task1, task2])
|
||||
plan._update_current_task()
|
||||
assert plan.current_task_id == "2"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
6
tests/metagpt/tools/libs/__init__.py
Normal file
6
tests/metagpt/tools/libs/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2024/1/11 16:14
|
||||
# @Author : lidanyang
|
||||
# @File : __init__.py
|
||||
# @Desc :
|
||||
111
tests/metagpt/tools/libs/test_data_preprocess.py
Normal file
111
tests/metagpt/tools/libs/test_data_preprocess.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.tools.libs.data_preprocess import (
|
||||
FillMissingValue,
|
||||
LabelEncode,
|
||||
MaxAbsScale,
|
||||
MinMaxScale,
|
||||
OneHotEncode,
|
||||
OrdinalEncode,
|
||||
RobustScale,
|
||||
StandardScale,
|
||||
get_column_info,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_datasets():
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"num1": [1, 2, np.nan, 4, 5],
|
||||
"cat1": ["A", "B", np.nan, "D", "A"],
|
||||
"date1": [
|
||||
datetime(2020, 1, 1),
|
||||
datetime(2020, 1, 2),
|
||||
datetime(2020, 1, 3),
|
||||
datetime(2020, 1, 4),
|
||||
datetime(2020, 1, 5),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_fill_missing_value(mock_datasets):
|
||||
fm = FillMissingValue(features=["num1"], strategy="mean")
|
||||
transformed = fm.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert transformed["num1"].isnull().sum() == 0
|
||||
|
||||
|
||||
def test_min_max_scale(mock_datasets):
|
||||
mms = MinMaxScale(features=["num1"])
|
||||
transformed = mms.fit_transform(mock_datasets.copy())
|
||||
|
||||
npt.assert_allclose(transformed["num1"].min(), 0)
|
||||
npt.assert_allclose(transformed["num1"].max(), 1)
|
||||
|
||||
|
||||
def test_standard_scale(mock_datasets):
|
||||
ss = StandardScale(features=["num1"])
|
||||
transformed = ss.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert int(transformed["num1"].mean()) == 0
|
||||
assert int(transformed["num1"].std()) == 1
|
||||
|
||||
|
||||
def test_max_abs_scale(mock_datasets):
|
||||
mas = MaxAbsScale(features=["num1"])
|
||||
transformed = mas.fit_transform(mock_datasets.copy())
|
||||
|
||||
npt.assert_allclose(transformed["num1"].abs().max(), 1)
|
||||
|
||||
|
||||
def test_robust_scale(mock_datasets):
|
||||
rs = RobustScale(features=["num1"])
|
||||
transformed = rs.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert int(transformed["num1"].median()) == 0
|
||||
|
||||
|
||||
def test_ordinal_encode(mock_datasets):
|
||||
oe = OrdinalEncode(features=["cat1"])
|
||||
transformed = oe.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert transformed["cat1"].max() == 2
|
||||
|
||||
|
||||
def test_one_hot_encode(mock_datasets):
|
||||
ohe = OneHotEncode(features=["cat1"])
|
||||
transformed = ohe.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert transformed["cat1_A"].max() == 1
|
||||
|
||||
|
||||
def test_label_encode(mock_datasets):
|
||||
le = LabelEncode(features=["cat1"])
|
||||
transformed = le.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert transformed["cat1"].max() == 3
|
||||
|
||||
# test transform with unseen data
|
||||
test = mock_datasets.copy()
|
||||
test["cat1"] = ["A", "B", "C", "D", "E"]
|
||||
transformed = le.transform(test)
|
||||
assert transformed["cat1"].max() == 4
|
||||
|
||||
|
||||
def test_get_column_info(mock_datasets):
|
||||
df = mock_datasets
|
||||
column_info = get_column_info(df)
|
||||
|
||||
assert column_info == {
|
||||
"Category": ["cat1"],
|
||||
"Numeric": ["num1"],
|
||||
"Datetime": ["date1"],
|
||||
"Others": [],
|
||||
}
|
||||
175
tests/metagpt/tools/libs/test_feature_engineering.py
Normal file
175
tests/metagpt/tools/libs/test_feature_engineering.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from sklearn.datasets import fetch_california_housing, load_breast_cancer, load_iris
|
||||
|
||||
from metagpt.tools.libs.feature_engineering import (
|
||||
CatCount,
|
||||
CatCross,
|
||||
ExtractTimeComps,
|
||||
GeneralSelection,
|
||||
GroupStat,
|
||||
KFoldTargetMeanEncoder,
|
||||
PolynomialExpansion,
|
||||
SplitBins,
|
||||
TargetMeanEncoder,
|
||||
TreeBasedSelection,
|
||||
VarianceBasedSelection,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset():
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"num1": [1, 2, np.nan, 4, 5, 6, 7, 3],
|
||||
"num2": [1, 3, 2, 1, np.nan, 5, 6, 4],
|
||||
"num3": [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
"cat1": ["A", "B", np.nan, "D", "E", "C", "B", "A"],
|
||||
"cat2": ["A", "A", "A", "A", "A", "A", "A", "A"],
|
||||
"date1": [
|
||||
"2020-01-01",
|
||||
"2020-01-02",
|
||||
"2020-01-03",
|
||||
"2020-01-04",
|
||||
"2020-01-05",
|
||||
"2020-01-06",
|
||||
"2020-01-07",
|
||||
"2020-01-08",
|
||||
],
|
||||
"label": [0, 1, 0, 1, 0, 1, 0, 1],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def load_sklearn_data(data_name):
|
||||
if data_name == "iris":
|
||||
data = load_iris()
|
||||
elif data_name == "breast_cancer":
|
||||
data = load_breast_cancer()
|
||||
elif data_name == "housing":
|
||||
data = fetch_california_housing()
|
||||
else:
|
||||
raise ValueError("data_name not supported")
|
||||
|
||||
X, y, feature_names = data.data, data.target, data.feature_names
|
||||
data = pd.DataFrame(X, columns=feature_names)
|
||||
data["label"] = y
|
||||
return data
|
||||
|
||||
|
||||
def test_polynomial_expansion(mock_dataset):
|
||||
pe = PolynomialExpansion(cols=["num1", "num2", "label"], degree=2, label_col="label")
|
||||
transformed = pe.fit_transform(mock_dataset)
|
||||
|
||||
assert len(transformed.columns) == len(mock_dataset.columns) + 3
|
||||
|
||||
# when too many columns
|
||||
data = load_sklearn_data("breast_cancer")
|
||||
cols = [c for c in data.columns if c != "label"]
|
||||
pe = PolynomialExpansion(cols=cols, degree=2, label_col="label")
|
||||
transformed = pe.fit_transform(data)
|
||||
|
||||
assert len(transformed.columns) == len(data.columns) + 55
|
||||
|
||||
|
||||
def test_cat_count(mock_dataset):
|
||||
cc = CatCount(col="cat1")
|
||||
transformed = cc.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_cnt" in transformed.columns
|
||||
assert transformed["cat1_cnt"][0] == 2
|
||||
|
||||
|
||||
def test_target_mean_encoder(mock_dataset):
|
||||
tme = TargetMeanEncoder(col="cat1", label="label")
|
||||
transformed = tme.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_target_mean" in transformed.columns
|
||||
assert transformed["cat1_target_mean"][0] == 0.5
|
||||
|
||||
|
||||
def test_kfold_target_mean_encoder(mock_dataset):
|
||||
kfme = KFoldTargetMeanEncoder(col="cat1", label="label")
|
||||
transformed = kfme.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_kf_target_mean" in transformed.columns
|
||||
|
||||
|
||||
def test_cat_cross(mock_dataset):
|
||||
cc = CatCross(cols=["cat1", "cat2"])
|
||||
transformed = cc.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_cat2" in transformed.columns
|
||||
|
||||
cc = CatCross(cols=["cat1", "cat2"], max_cat_num=3)
|
||||
transformed = cc.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_cat2" not in transformed.columns
|
||||
|
||||
|
||||
def test_group_stat(mock_dataset):
|
||||
gs = GroupStat(group_col="cat1", agg_col="num1", agg_funcs=["mean", "sum"])
|
||||
transformed = gs.fit_transform(mock_dataset)
|
||||
|
||||
assert "num1_mean_by_cat1" in transformed.columns
|
||||
assert "num1_sum_by_cat1" in transformed.columns
|
||||
|
||||
|
||||
def test_split_bins(mock_dataset):
|
||||
sb = SplitBins(cols=["num1"])
|
||||
transformed = sb.fit_transform(mock_dataset)
|
||||
|
||||
assert transformed["num1"].nunique() <= 5
|
||||
assert all(0 <= x < 5 for x in transformed["num1"])
|
||||
|
||||
|
||||
def test_extract_time_comps(mock_dataset):
|
||||
time_comps = ["year", "month", "day", "hour", "dayofweek", "is_weekend"]
|
||||
etc = ExtractTimeComps(time_col="date1", time_comps=time_comps)
|
||||
transformed = etc.fit_transform(mock_dataset.copy())
|
||||
|
||||
for comp in time_comps:
|
||||
assert comp in transformed.columns
|
||||
assert transformed["year"][0] == 2020
|
||||
assert transformed["month"][0] == 1
|
||||
assert transformed["day"][0] == 1
|
||||
assert transformed["hour"][0] == 0
|
||||
assert transformed["dayofweek"][0] == 3
|
||||
assert transformed["is_weekend"][0] == 0
|
||||
|
||||
|
||||
def test_general_selection(mock_dataset):
|
||||
gs = GeneralSelection(label_col="label")
|
||||
transformed = gs.fit_transform(mock_dataset.copy())
|
||||
|
||||
assert "num3" not in transformed.columns
|
||||
assert "cat2" not in transformed.columns
|
||||
|
||||
|
||||
@pytest.mark.skip # skip because TreeBasedSelection needs lgb as dependency
|
||||
def test_tree_based_selection(mock_dataset):
|
||||
# regression
|
||||
data = load_sklearn_data("housing")
|
||||
tbs = TreeBasedSelection(label_col="label", task_type="reg")
|
||||
transformed = tbs.fit_transform(data)
|
||||
assert len(transformed.columns) > 1
|
||||
|
||||
# classification
|
||||
data = load_sklearn_data("breast_cancer")
|
||||
tbs = TreeBasedSelection(label_col="label", task_type="cls")
|
||||
transformed = tbs.fit_transform(data)
|
||||
assert len(transformed.columns) > 1
|
||||
|
||||
# multi-classification
|
||||
data = load_sklearn_data("iris")
|
||||
tbs = TreeBasedSelection(label_col="label", task_type="mcls")
|
||||
transformed = tbs.fit_transform(data)
|
||||
assert len(transformed.columns) > 1
|
||||
|
||||
|
||||
def test_variance_based_selection(mock_dataset):
|
||||
vbs = VarianceBasedSelection(label_col="label")
|
||||
transformed = vbs.fit_transform(mock_dataset.copy())
|
||||
|
||||
assert "num3" not in transformed.columns
|
||||
85
tests/metagpt/tools/libs/test_gpt_v_generator.py
Normal file
85
tests/metagpt/tools/libs/test_gpt_v_generator.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/15
|
||||
@Author : mannaandpoem
|
||||
@File : test_gpt_v_generator.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt import logs
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.tools.libs.gpt_v_generator import GPTvGenerator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_webpage_filename_with_styles_and_scripts(mocker):
|
||||
mock_data = """```html\n<html>\n<script src="scripts.js"></script>
|
||||
<link rel="stylesheet" href="styles.css">\n</html>\n```\n
|
||||
```css\n/* styles.css */\n```\n
|
||||
```javascript\n// scripts.js\n```\n"""
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", return_value=mock_data)
|
||||
return mocker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_webpage_filename_with_style_and_script(mocker):
|
||||
mock_data = """```html\n<html>\n<script src="script.js"></script>
|
||||
<link rel="stylesheet" href="style.css">\n</html>\n```\n
|
||||
```css\n/* style.css */\n```\n
|
||||
```javascript\n// script.js\n```\n"""
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", return_value=mock_data)
|
||||
return mocker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_image_layout(mocker):
|
||||
image_layout = "The layout information of the sketch image is ..."
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", return_value=image_layout)
|
||||
return mocker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_path():
|
||||
return f"{METAGPT_ROOT}/docs/resources/workspace/content_rec_sys/resources/competitive_analysis.png"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_webpages(mock_webpage_filename_with_styles_and_scripts, image_path):
|
||||
generator = GPTvGenerator()
|
||||
rsp = await generator.generate_webpages(image_path=image_path)
|
||||
logs.logger.info(rsp)
|
||||
assert "html" in rsp
|
||||
assert "css" in rsp
|
||||
assert "javascript" in rsp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_webpages_with_styles_and_scripts(mock_webpage_filename_with_styles_and_scripts, image_path):
|
||||
generator = GPTvGenerator()
|
||||
webpages = await generator.generate_webpages(image_path)
|
||||
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
|
||||
logs.logger.info(webpages_dir)
|
||||
assert webpages_dir.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_webpages_with_style_and_script(mock_webpage_filename_with_style_and_script, image_path):
|
||||
generator = GPTvGenerator()
|
||||
webpages = await generator.generate_webpages(image_path)
|
||||
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
|
||||
logs.logger.info(webpages_dir)
|
||||
assert webpages_dir.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_layout(mock_image_layout, image_path):
|
||||
layout = await GPTvGenerator().analyze_layout(Path(image_path))
|
||||
logs.logger.info(layout)
|
||||
assert layout
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
61
tests/metagpt/tools/libs/test_sd_engine.py
Normal file
61
tests/metagpt/tools/libs/test_sd_engine.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/10/2024 10:07 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from metagpt.tools.libs.sd_engine import SDEngine
|
||||
|
||||
|
||||
def generate_mock_image_data():
|
||||
# 创建一个简单的图片对象
|
||||
image = Image.new("RGB", (100, 100), color="white")
|
||||
draw = ImageDraw.Draw(image)
|
||||
draw.text((10, 10), "Mock Image", fill="black")
|
||||
|
||||
# 将图片转换为二进制数据
|
||||
with io.BytesIO() as buffer:
|
||||
image.save(buffer, format="PNG")
|
||||
image_binary = buffer.getvalue()
|
||||
|
||||
# 对图片二进制数据进行 base64 编码
|
||||
image_base64 = base64.b64encode(image_binary).decode("utf-8")
|
||||
|
||||
return image_base64
|
||||
|
||||
|
||||
def test_sd_tools(mocker):
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.json.return_value = {"images": [generate_mock_image_data()]}
|
||||
mocker.patch("requests.Session.post", return_value=mock_response)
|
||||
|
||||
engine = SDEngine(sd_url="http://example_localhost:7860")
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
engine.simple_run_t2i(engine.payload)
|
||||
|
||||
|
||||
def test_sd_construct_payload():
|
||||
engine = SDEngine(sd_url="http://example_localhost:7860")
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
assert "negative_prompt" in engine.payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sd_asyn_t2i(mocker):
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.post")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.read.return_value = json.dumps({"images": [generate_mock_image_data()]})
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
engine = SDEngine(sd_url="http://example_localhost:7860")
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
await engine.run_t2i([engine.payload])
|
||||
assert "negative_prompt" in engine.payload
|
||||
23
tests/metagpt/tools/libs/test_web_scraping.py
Normal file
23
tests/metagpt/tools/libs/test_web_scraping.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.libs.web_scraping import scrape_web_playwright
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scrape_web_playwright():
|
||||
test_url = "https://www.deepwisdom.ai"
|
||||
|
||||
result = await scrape_web_playwright(test_url)
|
||||
|
||||
# Assert that the result is a dictionary
|
||||
assert isinstance(result, dict)
|
||||
|
||||
# Assert that the result contains 'inner_text' and 'html' keys
|
||||
assert "inner_text" in result
|
||||
assert "html" in result
|
||||
|
||||
# Assert startswith and endswith
|
||||
assert not result["inner_text"].startswith(" ")
|
||||
assert not result["inner_text"].endswith(" ")
|
||||
assert not result["html"].startswith(" ")
|
||||
assert not result["html"].endswith(" ")
|
||||
175
tests/metagpt/tools/test_tool_convert.py
Normal file
175
tests/metagpt/tools/test_tool_convert.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
import pandas as pd
|
||||
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
|
||||
|
||||
|
||||
def test_docstring_to_schema():
|
||||
docstring = """
|
||||
Some test desc.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
|
||||
used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
expected = {
|
||||
"description": "Some test desc.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
}
|
||||
schema = docstring_to_schema(docstring)
|
||||
assert schema == expected
|
||||
|
||||
|
||||
class DummyClass:
|
||||
"""
|
||||
Completing missing values with simple strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list, strategy: str = "mean", fill_value=None):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only
|
||||
be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
"""
|
||||
pass
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the FillMissingValue model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
pass
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def dummy_fn(df: pd.DataFrame) -> dict:
|
||||
"""
|
||||
Analyzes a DataFrame and categorizes its columns based on data types.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The DataFrame to be analyzed.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
|
||||
Each key corresponds to a list of column names belonging to that category.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
async def dummy_async_fn(df: pd.DataFrame) -> dict:
|
||||
"""
|
||||
A dummy async function for test
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): test args.
|
||||
|
||||
Returns:
|
||||
dict: test returns.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def test_convert_code_to_tool_schema_class():
|
||||
expected = {
|
||||
"type": "class",
|
||||
"description": "Completing missing values with simple strategies.",
|
||||
"methods": {
|
||||
"__init__": {
|
||||
"type": "function",
|
||||
"description": "Initialize self.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
},
|
||||
"fit": {
|
||||
"type": "function",
|
||||
"description": "Fit the FillMissingValue model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
},
|
||||
"transform": {
|
||||
"type": "function",
|
||||
"description": "Transform the input DataFrame with the fitted model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
},
|
||||
},
|
||||
}
|
||||
schema = convert_code_to_tool_schema(DummyClass)
|
||||
assert schema == expected
|
||||
|
||||
|
||||
def test_convert_code_to_tool_schema_function():
|
||||
expected = {
|
||||
"type": "function",
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
}
|
||||
schema = convert_code_to_tool_schema(dummy_fn)
|
||||
assert schema == expected
|
||||
|
||||
|
||||
def test_convert_code_to_tool_schema_async_function():
|
||||
schema = convert_code_to_tool_schema(dummy_async_fn)
|
||||
assert schema.get("type") == "async_function"
|
||||
102
tests/metagpt/tools/test_tool_registry.py
Normal file
102
tests/metagpt/tools/test_tool_registry.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.tool_registry import ToolRegistry
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry():
|
||||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry_full():
|
||||
return ToolRegistry(tool_types=ToolType)
|
||||
|
||||
|
||||
# Test Initialization
|
||||
def test_initialization(tool_registry):
|
||||
assert isinstance(tool_registry, ToolRegistry)
|
||||
assert tool_registry.tools == {}
|
||||
assert tool_registry.tool_types == {}
|
||||
assert tool_registry.tools_by_types == {}
|
||||
|
||||
|
||||
# Test Initialization with tool types
|
||||
def test_initialize_with_tool_types(tool_registry_full):
|
||||
assert isinstance(tool_registry_full, ToolRegistry)
|
||||
assert tool_registry_full.tools == {}
|
||||
assert tool_registry_full.tools_by_types == {}
|
||||
assert "data_preprocess" in tool_registry_full.tool_types
|
||||
|
||||
|
||||
class TestClassTool:
|
||||
"""test class"""
|
||||
|
||||
def test_class_fn(self):
|
||||
"""test class fn"""
|
||||
pass
|
||||
|
||||
|
||||
def test_fn():
|
||||
"""test function"""
|
||||
pass
|
||||
|
||||
|
||||
# Test Tool Registration Class
|
||||
def test_register_tool_class(tool_registry):
|
||||
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
|
||||
assert "TestClassTool" in tool_registry.tools
|
||||
|
||||
|
||||
# Test Tool Registration Function
|
||||
def test_register_tool_fn(tool_registry):
|
||||
tool_registry.register_tool("test_fn", "/path/to/tool", tool_source_object=test_fn)
|
||||
assert "test_fn" in tool_registry.tools
|
||||
|
||||
|
||||
# Test Tool Existence Checks
|
||||
def test_has_tool(tool_registry):
|
||||
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
|
||||
assert tool_registry.has_tool("TestClassTool")
|
||||
assert not tool_registry.has_tool("NonexistentTool")
|
||||
|
||||
|
||||
# Test Tool Retrieval
|
||||
def test_get_tool(tool_registry):
|
||||
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
|
||||
tool = tool_registry.get_tool("TestClassTool")
|
||||
assert tool is not None
|
||||
assert tool.name == "TestClassTool"
|
||||
assert tool.path == "/path/to/tool"
|
||||
assert "description" in tool.schemas
|
||||
|
||||
|
||||
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
|
||||
def test_has_tool_type(tool_registry_full):
|
||||
assert tool_registry_full.has_tool_type("data_preprocess")
|
||||
assert not tool_registry_full.has_tool_type("NonexistentType")
|
||||
|
||||
|
||||
def test_get_tool_type(tool_registry_full):
|
||||
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
|
||||
assert retrieved_type is not None
|
||||
assert retrieved_type.name == "data_preprocess"
|
||||
|
||||
|
||||
def test_get_tools_by_type(tool_registry):
|
||||
tool_type_name = "TestType"
|
||||
tool_name = "TestTool"
|
||||
tool_path = "/path/to/tool"
|
||||
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
|
||||
|
||||
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
|
||||
assert tools_by_type is not None
|
||||
assert tool_name in tools_by_type
|
||||
|
||||
|
||||
# Test case for when the tool type does not exist
|
||||
def test_get_tools_by_nonexistent_type(tool_registry):
|
||||
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
|
||||
assert not tools_by_type
|
||||
|
|
@ -135,7 +135,7 @@ def test_repair_json_format():
|
|||
}
|
||||
"""
|
||||
target_output = """{
|
||||
"Language": "en_us",
|
||||
"Language": "en_us",
|
||||
"Programming Language": "Python"
|
||||
}"""
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
|
|
@ -148,7 +148,7 @@ def test_repair_json_format():
|
|||
}
|
||||
"""
|
||||
target_output = """{
|
||||
"Language": "en_us",
|
||||
"Language": "en_us",
|
||||
"Programming Language": "Python"
|
||||
}"""
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
|
|
@ -161,7 +161,7 @@ def test_repair_json_format():
|
|||
}
|
||||
"""
|
||||
target_output = """{
|
||||
"Language": "#en_us#",
|
||||
"Language": "#en_us#",
|
||||
"Programming Language": "//Python # Code // Language//"
|
||||
}"""
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
|
|
|
|||
44
tests/metagpt/utils/test_save_code.py
Normal file
44
tests/metagpt/utils/test_save_code.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/12/2023 4:17 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import nbformat
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.utils.common import read_json_file
|
||||
from metagpt.utils.save_code import DATA_PATH, save_code_file
|
||||
|
||||
|
||||
def test_save_code_file_python():
|
||||
save_code_file("example", "print('Hello, World!')")
|
||||
file_path = DATA_PATH / "output" / "example" / "code.py"
|
||||
assert file_path.exists(), f"File does not exist: {file_path}"
|
||||
content = file_path.read_text()
|
||||
assert "print('Hello, World!')" in content, "File content does not match"
|
||||
|
||||
|
||||
def test_save_code_file_json():
|
||||
save_code_file("example_json", "print('Hello, JSON!')", file_format="json")
|
||||
file_path = DATA_PATH / "output" / "example_json" / "code.json"
|
||||
data = read_json_file(file_path)
|
||||
assert "code" in data, "JSON key 'code' is missing"
|
||||
assert data["code"] == "print('Hello, JSON!')", "JSON content does not match"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_code_file_notebook():
|
||||
code = "print('Hello, World!')"
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code)
|
||||
# Save as a Notebook file
|
||||
save_code_file("example_nb", executor.nb, file_format="ipynb")
|
||||
file_path = DATA_PATH / "output" / "example_nb" / "code.ipynb"
|
||||
assert file_path.exists(), f"Notebook file does not exist: {file_path}"
|
||||
|
||||
# Additional checks specific to notebook format
|
||||
notebook = nbformat.read(file_path, as_version=4)
|
||||
assert len(notebook.cells) > 0, "Notebook should have at least one cell"
|
||||
first_cell_source = notebook.cells[0].source
|
||||
assert "print" in first_cell_source, "Notebook cell content does not match"
|
||||
|
|
@ -1,13 +1,22 @@
|
|||
import json
|
||||
from typing import Optional, Union
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM
|
||||
|
||||
|
||||
class MockLLM(OpenAILLM):
|
||||
class MockLLM(OriginalLLM):
|
||||
def __init__(self, allow_open_api_call):
|
||||
super().__init__(config.get_openai_llm())
|
||||
original_llm_config = (
|
||||
config.get_openai_llm() if config.llm.api_type == LLMType.OPENAI else config.get_azure_llm()
|
||||
)
|
||||
super().__init__(original_llm_config)
|
||||
self.allow_open_api_call = allow_open_api_call
|
||||
self.rsp_cache: dict = {}
|
||||
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
|
||||
|
|
@ -62,6 +71,14 @@ class MockLLM(OpenAILLM):
|
|||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def original_aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
"""
|
||||
A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it will be mocked.
|
||||
Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
|
||||
"""
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
|
|
@ -83,6 +100,12 @@ class MockLLM(OpenAILLM):
|
|||
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
|
||||
return rsp
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
messages = self._process_message(messages)
|
||||
msg_key = json.dumps(messages, ensure_ascii=False)
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
|
||||
return rsp
|
||||
|
||||
async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
|
||||
if msg_key not in self.rsp_cache:
|
||||
if not self.allow_open_api_call:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue