add support for open-interpreter.

This commit is contained in:
刘棒棒 2023-09-13 15:47:51 +08:00
parent a90c4309a0
commit 754fa5ccbe
5 changed files with 315 additions and 0 deletions

View file

@ -0,0 +1,54 @@
import pytest
from metagpt.actions.clone_function import CloneFunction, run_fucntion_code
source_code = """
def user_indicator():
import pandas as pd
import ta
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
"""
template_code = """
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame:
import pandas as pd
# here is your code.
"""
def get_expected_res():
import pandas as pd
import ta
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
return stock_data
@pytest.mark.asyncio
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert 'def ' in code
stock_path = './tests/data/baba_stock.csv'
df, msg = run_fucntion_code(code, 'stock_indicator', stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)

View file

@ -0,0 +1,42 @@
import pytest
import pandas as pd
from pathlib import Path
from tests.data import sales_desc, store_desc
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
from metagpt.actions import Action
from metagpt.logs import logger
logger.add('./tests/data/test_ci.log')
stock = "./tests/data/baba_stock.csv"
# TODO: 需要一种表格数据格式能够支持schame管理的标注字段类型和字段含义。
class CreateStockIndicators(Action):
@OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
async def run(self, stock_path: str, indicators=['Simple Moving Average', 'BollingerBands']) -> pd.DataFrame:
"""对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
指标生成对应的三列: SMA, BB_upper, BB_lower
"""
...
@pytest.mark.asyncio
async def test_actions():
# 计算指标
indicators = ['Simple Moving Average', 'BollingerBands']
stocker = CreateStockIndicators()
df, msg = await stocker.run(stock, indicators=indicators)
assert isinstance(df, pd.DataFrame)
assert 'Close' in df.columns
assert 'Date' in df.columns
# 将df保存为文件将文件路径传入到下一个action
df_path = './tests/data/stock_indicators.csv'
df.to_csv(df_path)
assert Path(df_path).is_file()
# 可视化指标结果
figure_path = './tests/data/figure_ci.png'
ci_ploter = OpenCodeInterpreter()
ci_ploter.chat(f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper布林带上界, BB_lower布林带下界进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算把Date列转换为日期类型。要求图片优美BB_upper, BB_lower之间使用合适的颜色填充。")
assert Path(figure_path).is_file()