Merge branch 'main' into feature/json_write_prd

This commit is contained in:
femto 2023-09-21 12:00:10 +08:00 committed by GitHub
commit 7b34c433cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
69 changed files with 1911 additions and 56 deletions

View file

@ -0,0 +1,54 @@
import pytest
from metagpt.actions.clone_function import CloneFunction, run_function_code
source_code = """
import pandas as pd
import ta
def user_indicator():
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
"""
template_code = """
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame:
import pandas as pd
# here is your code.
"""
def get_expected_res():
import pandas as pd
import ta
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
return stock_data
@pytest.mark.asyncio
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert 'def ' in code
stock_path = './tests/data/baba_stock.csv'
df, msg = run_function_code(code, 'stock_indicator', stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/16 20:03
@Author : femto Zheng
@File : __init__.py
"""

View file

@ -0,0 +1,29 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/16 20:03
@Author : femto Zheng
@File : test_basic_planner.py
"""
import pytest
from semantic_kernel.core_skills import FileIOSkill, MathSkill, TextSkill, TimeSkill
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
from metagpt.actions import BossRequirement
from metagpt.roles.sk_agent import SkAgent
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_action_planner():
role = SkAgent(planner_cls=ActionPlanner)
# let's give the agent 4 skills
role.import_skill(MathSkill(), "math")
role.import_skill(FileIOSkill(), "fileIO")
role.import_skill(TimeSkill(), "time")
role.import_skill(TextSkill(), "text")
task = "What is the sum of 110 and 990?"
role.recv(Message(content=task, cause_by=BossRequirement))
await role._think() # it will choose mathskill.Add
assert "1100" == (await role._act()).content

View file

@ -0,0 +1,34 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/16 20:03
@Author : femto Zheng
@File : test_basic_planner.py
"""
import pytest
from semantic_kernel.core_skills import TextSkill
from metagpt.actions import BossRequirement
from metagpt.const import SKILL_DIRECTORY
from metagpt.roles.sk_agent import SkAgent
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_basic_planner():
task = """
Tomorrow is Valentine's day. I need to come up with a few date ideas. She speaks French so write it in French.
Convert the text to uppercase"""
role = SkAgent()
# let's give the agent some skills
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "SummarizeSkill")
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "WriterSkill")
role.import_skill(TextSkill(), "TextSkill")
# using BasicPlanner
role.recv(Message(content=task, cause_by=BossRequirement))
await role._think()
# assuming sk_agent will think he needs WriterSkill.Brainstorm and WriterSkill.Translate
assert "WriterSkill.Brainstorm" in role.plan.generated_plan.result
assert "WriterSkill.Translate" in role.plan.generated_plan.result
# assert "SALUT" in (await role._act()).content #content will be some French

View file

@ -0,0 +1,42 @@
import pytest
import pandas as pd
from pathlib import Path
from tests.data import sales_desc, store_desc
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
from metagpt.actions import Action
from metagpt.logs import logger
logger.add('./tests/data/test_ci.log')
stock = "./tests/data/baba_stock.csv"
# TODO: 需要一种表格数据格式能够支持schame管理的标注字段类型和字段含义。
class CreateStockIndicators(Action):
@OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
async def run(self, stock_path: str, indicators=['Simple Moving Average', 'BollingerBands']) -> pd.DataFrame:
"""对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
指标生成对应的三列: SMA, BB_upper, BB_lower
"""
...
@pytest.mark.asyncio
async def test_actions():
# 计算指标
indicators = ['Simple Moving Average', 'BollingerBands']
stocker = CreateStockIndicators()
df, msg = await stocker.run(stock, indicators=indicators)
assert isinstance(df, pd.DataFrame)
assert 'Close' in df.columns
assert 'Date' in df.columns
# 将df保存为文件将文件路径传入到下一个action
df_path = './tests/data/stock_indicators.csv'
df.to_csv(df_path)
assert Path(df_path).is_file()
# 可视化指标结果
figure_path = './tests/data/figure_ci.png'
ci_ploter = OpenCodeInterpreter()
ci_ploter.chat(f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper布林带上界, BB_lower布林带下界进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算把Date列转换为日期类型。要求图片优美BB_upper, BB_lower之间使用合适的颜色填充。")
assert Path(figure_path).is_file()

View file

@ -7,7 +7,6 @@
"""
from pathlib import Path
import aiofiles
import pytest
from metagpt.utils.file import File
@ -18,10 +17,10 @@ from metagpt.utils.file import File
("root_path", "filename", "content"),
[(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")]
)
async def test_write_file(root_path: Path, filename: str, content: bytes):
async def test_write_and_read_file(root_path: Path, filename: str, content: bytes):
full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode('utf-8'))
assert isinstance(full_file_name, Path)
assert root_path / filename == full_file_name
async with aiofiles.open(full_file_name, mode="r") as reader:
body = await reader.read()
assert body == content
file_data = await File.read(full_file_name)
assert file_data.decode("utf-8") == content

View file

@ -5,7 +5,7 @@
@Author : chengmaoyu
@File : test_output_parser.py
"""
from typing import List, Tuple
from typing import List, Tuple, Union
import pytest
@ -64,6 +64,59 @@ def test_parse_data():
assert OutputParser.parse_data(test_data) == expected_result
@pytest.mark.parametrize(
("text", "data_type", "parsed_data", "expected_exception"),
[
(
"""xxx [1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}] xxx""",
list,
[1, 2, ["a", "b", [3, 4]], {"x": 5, "y": [6, 7]}],
None,
),
(
"""xxx ["1", "2", "3"] xxx \n xxx \t xx""",
list,
["1", "2", "3"],
None,
),
(
"""{"title": "a", "directory": {"sub_dir1": ["title1, title2"]}, "sub_dir2": [1, 2]}""",
dict,
{"title": "a", "directory": {"sub_dir1": ["title1, title2"]}, "sub_dir2": [1, 2]},
None,
),
(
"""xxx {"title": "x", \n \t "directory": ["x", \n "y"]} xxx \n xxx \t xx""",
dict,
{"title": "x", "directory": ["x", "y"]},
None,
),
(
"""xxx xx""",
list,
None,
Exception,
),
(
"""xxx [1, 2, []xx""",
list,
None,
Exception,
),
]
)
def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception):
def case():
resp = OutputParser.extract_struct(text, data_type)
assert resp == parsed_data
if expected_exception:
with pytest.raises(expected_exception):
case()
else:
case()
if __name__ == '__main__':
t_text = '''
## Required Python third-party packages