add support for open-interpreter.

This commit is contained in:
刘棒棒 2023-09-13 15:47:51 +08:00
parent a90c4309a0
commit 754fa5ccbe
5 changed files with 315 additions and 0 deletions

View file

@ -0,0 +1,66 @@
from pathlib import Path
import traceback
from metagpt.actions.write_code import WriteCode
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.highlight import highlight
CLONE_PROMPT = """
*context*
Please convert the function code ```{source_code}``` into the the function format: ```{template_func}```.
*Please Write code based on the following list and context*
1. Write code start with ```, and end with ```.
2. Please implement it in one function if possible, except for import statements. for exmaple:
```python
import pandas as pd
def run(*args) -> pd.DataFrame:
...
```
3. Do not use public member functions that do not exist in your design.
4. The output function name, input parameters and return value must be the same as ```{template_func}```.
5. Make sure the results before and after the code conversion are required to be exactly the same.
6. Don't repeat my context in your replies.
7. Return full results, for example, if the return value has df.head(), please return df.
8. If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ...
"""
class CloneFunction(WriteCode):
def __init__(self, name="CloneFunction", context: list[Message] = None, llm=None):
super().__init__(name, context, llm)
def _save(self, code_path, code):
if isinstance(code_path, str):
code_path = Path(code_path)
code_path.parent.mkdir(parents=True, exist_ok=True)
code_path.write_text(code)
logger.info(f"Saving Code to {code_path}")
async def run(self, template_func: str, source_code: str) -> str:
"""将source_code转换成template_func一样的入参和返回类型"""
prompt = CLONE_PROMPT.format(source_code=source_code, template_func=template_func)
logger.info(f"query for CloneFunction: \n {prompt}")
code = await self.write_code(prompt)
assert 'def' in code
logger.info(f'CloneFunction code is \n {highlight(code)}')
return code
def run_fucntion_code(func_code: str, func_name: str, *args, **kwargs):
"""执行函数类生成代码"""
try:
locals_ = {}
exec(func_code, locals_)
func = locals_[func_name]
return func(*args, **kwargs), ""
except Exception:
return "", traceback.format_exc()
def run_fucntion_script(code_script_path: str, func_name: str, *args, **kwargs):
"""从脚本中载入函数进行执行"""
if isinstance(code_script_path, str):
code_path = Path(code_script_path)
code = code_path.read_text(encoding='utf-8')
return run_fucntion_code(code, func_name, *args, **kwargs)

View file

@ -0,0 +1,128 @@
import re
from typing import List, Callable
from pathlib import Path
import wrapt
import textwrap
import inspect
from interpreter.interpreter import Interpreter
from metagpt.logs import logger
from metagpt.config import CONFIG
from metagpt.utils.highlight import highlight
from metagpt.actions.clone_function import CloneFunction, run_fucntion_code, run_fucntion_script
from metagpt.actions.run_code import RunCode
def extract_python_code(code: str):
"""提取代码块: 如果代码注释相同,则只保留最后一个代码块."""
# 使用正则表达式匹配注释块和相关的代码
pattern = r'(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)'
matches = re.findall(pattern, code, re.DOTALL)
# 提取每个相同注释的最后一个代码块
unique_comments = {}
for comment, code_block in matches:
unique_comments[comment] = code_block
# 组装结果字符串
result_code = '\n'.join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
header_code = code[:code.find("#")]
code = header_code + result_code
logger.info(f"Extract python code: \n {highlight(code)}")
return code
class OpenCodeInterpreter(object):
"""https://github.com/KillianLucas/open-interpreter"""
def __init__(self, auto_run: bool = True) -> None:
interpreter = Interpreter()
interpreter.auto_run = auto_run
interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
interpreter.api_key = CONFIG.openai_api_key
self.interpreter = interpreter
def chat(self, query: str, reset: bool = True):
if reset:
self.interpreter.reset()
return self.interpreter.chat(query, return_messages=True)
@staticmethod
def extract_function(query_respond: List, function_name: str, *, language: str = 'python',
function_format: str = None) -> str:
"""create a function from query_respond."""
if language not in ('python'):
raise NotImplementedError(f"Not support to parse language {language}!")
# 定义函数形式
if function_format is None:
assert language == 'python', f"Expect python language for default function_format, but got {language}."
function_format = """def {function_name}():\n{code}"""
# 解析open-interpreter respond message中的代码模块
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and 'language' in item["function_call"]['parsed_arguments']
and item["function_call"]['parsed_arguments']['language'] == language]
# 添加缩进
indented_code_str = textwrap.indent("\n".join(code), ' ' * 4)
# 按照代码注释, 返回去重后的代码
if language == "python":
return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))
def gen_query(func: Callable, args, kwargs) -> str:
# 函数的注释, 也就是query的主体
desc = func.__doc__
signature = inspect.signature(func)
# 获取函数wrapped的签名和入参的赋值
bound_args = signature.bind(*args, **kwargs)
bound_args.apply_defaults()
query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..."
return query
def gen_template_fun(func: Callable) -> str:
return f"def {func.__name__}{str(inspect.signature(func))}\n # here is your code ..."
class OpenInterpreterDecorator(object):
def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None:
self.save_code = save_code
self.code_file_path = code_file_path
self.clear_code = clear_code
def __call__(self, wrapped):
@wrapt.decorator
async def wrapper(wrapped: Callable, instance, args, kwargs):
# 获取被装饰的函数名
func_name = wrapped.__name__
# 如果脚本在本地存在而且不需要clearcode则从脚本执行该函数
if Path(self.code_file_path).is_file() and not self.clear_code:
return run_fucntion_script(self.code_file_path, func_name, *args, **kwargs)
# 使用open-interpreter逐步生成代码
interpreter = OpenCodeInterpreter()
query = gen_query(wrapped, args, kwargs)
logger.info(f"query for OpenCodeInterpreter: \n {query}")
respond = interpreter.chat(query)
# 将open-interpreter逐步生成的代码组装成无入参的函数
func_code = interpreter.extract_function(respond, func_name)
# 把code克隆为wrapped即保持code和wrapped函数有相同的入参和返回值类型
template_func = gen_template_fun(wrapped)
cf = CloneFunction()
code = await cf.run(template_func=template_func, source_code=func_code)
# 在终端显示生成的函数
logger_code = highlight(code, "python")
logger.info(f"Creating following Python function:\n{logger_code}")
# 执行该函数
try:
res = run_fucntion_code(code, func_name, *args, **kwargs)
if self.save_code:
cf._save(self.code_file_path, code)
except Exception as e:
raise Exception("Could not evaluate Python code", e)
return res
return wrapper(wrapped)

View file

@ -0,0 +1,25 @@
# 添加代码语法高亮显示
from pygments import highlight as highlight_
from pygments.lexers import PythonLexer, SqlLexer
from pygments.formatters import TerminalFormatter, HtmlFormatter
def highlight(code: str, language: str = 'python', formatter: str = 'terminal'):
# 指定要高亮的语言
if language.lower() == 'python':
lexer = PythonLexer()
elif language.lower() == 'sql':
lexer = SqlLexer()
else:
raise ValueError(f"Unsupported language: {language}")
# 指定输出格式
if formatter.lower() == 'terminal':
formatter = TerminalFormatter()
elif formatter.lower() == 'html':
formatter = HtmlFormatter()
else:
raise ValueError(f"Unsupported formatter: {formatter}")
# 使用 Pygments 高亮代码片段
return highlight_(code, lexer, formatter)

View file

@ -0,0 +1,54 @@
import pytest
from metagpt.actions.clone_function import CloneFunction, run_fucntion_code
source_code = """
def user_indicator():
import pandas as pd
import ta
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
"""
template_code = """
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame:
import pandas as pd
# here is your code.
"""
def get_expected_res():
import pandas as pd
import ta
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
return stock_data
@pytest.mark.asyncio
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert 'def ' in code
stock_path = './tests/data/baba_stock.csv'
df, msg = run_fucntion_code(code, 'stock_indicator', stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)

View file

@ -0,0 +1,42 @@
import pytest
import pandas as pd
from pathlib import Path
from tests.data import sales_desc, store_desc
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
from metagpt.actions import Action
from metagpt.logs import logger
logger.add('./tests/data/test_ci.log')
stock = "./tests/data/baba_stock.csv"
# TODO: 需要一种表格数据格式能够支持schame管理的标注字段类型和字段含义。
class CreateStockIndicators(Action):
@OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
async def run(self, stock_path: str, indicators=['Simple Moving Average', 'BollingerBands']) -> pd.DataFrame:
"""对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
指标生成对应的三列: SMA, BB_upper, BB_lower
"""
...
@pytest.mark.asyncio
async def test_actions():
# 计算指标
indicators = ['Simple Moving Average', 'BollingerBands']
stocker = CreateStockIndicators()
df, msg = await stocker.run(stock, indicators=indicators)
assert isinstance(df, pd.DataFrame)
assert 'Close' in df.columns
assert 'Date' in df.columns
# 将df保存为文件将文件路径传入到下一个action
df_path = './tests/data/stock_indicators.csv'
df.to_csv(df_path)
assert Path(df_path).is_file()
# 可视化指标结果
figure_path = './tests/data/figure_ci.png'
ci_ploter = OpenCodeInterpreter()
ci_ploter.chat(f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper布林带上界, BB_lower布林带下界进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算把Date列转换为日期类型。要求图片优美BB_upper, BB_lower之间使用合适的颜色填充。")
assert Path(figure_path).is_file()