Merge pull request #315 from orange-crow/add_open_interpreter

add support for open-interpreter.
This commit is contained in:
stellaHSR 2023-09-21 10:48:11 +08:00 committed by GitHub
commit f6b7a7f1d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 317 additions and 1 deletions

View file

@ -0,0 +1,65 @@
from pathlib import Path
import traceback
from metagpt.actions.write_code import WriteCode
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.highlight import highlight
CLONE_PROMPT = """
*context*
Please convert the function code ```{source_code}``` into the the function format: ```{template_func}```.
*Please Write code based on the following list and context*
1. Write code start with ```, and end with ```.
2. Please implement it in one function if possible, except for import statements. for exmaple:
```python
import pandas as pd
def run(*args) -> pd.DataFrame:
...
```
3. Do not use public member functions that do not exist in your design.
4. The output function name, input parameters and return value must be the same as ```{template_func}```.
5. Make sure the results before and after the code conversion are required to be exactly the same.
6. Don't repeat my context in your replies.
7. Return full results, for example, if the return value has df.head(), please return df.
8. If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ...
"""
class CloneFunction(WriteCode):
def __init__(self, name="CloneFunction", context: list[Message] = None, llm=None):
super().__init__(name, context, llm)
def _save(self, code_path, code):
if isinstance(code_path, str):
code_path = Path(code_path)
code_path.parent.mkdir(parents=True, exist_ok=True)
code_path.write_text(code)
logger.info(f"Saving Code to {code_path}")
async def run(self, template_func: str, source_code: str) -> str:
"""将source_code转换成template_func一样的入参和返回类型"""
prompt = CLONE_PROMPT.format(source_code=source_code, template_func=template_func)
logger.info(f"query for CloneFunction: \n {prompt}")
code = await self.write_code(prompt)
logger.info(f'CloneFunction code is \n {highlight(code)}')
return code
def run_function_code(func_code: str, func_name: str, *args, **kwargs):
"""Run function code from string code."""
try:
locals_ = {}
exec(func_code, locals_)
func = locals_[func_name]
return func(*args, **kwargs), ""
except Exception:
return "", traceback.format_exc()
def run_function_script(code_script_path: str, func_name: str, *args, **kwargs):
"""Run function code from script."""
if isinstance(code_script_path, str):
code_path = Path(code_script_path)
code = code_path.read_text(encoding='utf-8')
return run_function_code(code, func_name, *args, **kwargs)

View file

@ -0,0 +1,129 @@
import re
from typing import List, Callable
from pathlib import Path
import wrapt
import textwrap
import inspect
from interpreter.interpreter import Interpreter
from metagpt.logs import logger
from metagpt.config import CONFIG
from metagpt.utils.highlight import highlight
from metagpt.actions.clone_function import CloneFunction, run_function_code, run_function_script
def extract_python_code(code: str):
"""Extract code blocks: If the code comments are the same, only the last code block is kept."""
# Use regular expressions to match comment blocks and related code.
pattern = r'(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)'
matches = re.findall(pattern, code, re.DOTALL)
# Extract the last code block when encountering the same comment.
unique_comments = {}
for comment, code_block in matches:
unique_comments[comment] = code_block
# concatenate into functional form
result_code = '\n'.join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
header_code = code[:code.find("#")]
code = header_code + result_code
logger.info(f"Extract python code: \n {highlight(code)}")
return code
class OpenCodeInterpreter(object):
"""https://github.com/KillianLucas/open-interpreter"""
def __init__(self, auto_run: bool = True) -> None:
interpreter = Interpreter()
interpreter.auto_run = auto_run
interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
interpreter.api_key = CONFIG.openai_api_key
interpreter.api_base = CONFIG.openai_api_base
self.interpreter = interpreter
def chat(self, query: str, reset: bool = True):
if reset:
self.interpreter.reset()
return self.interpreter.chat(query, return_messages=True)
@staticmethod
def extract_function(query_respond: List, function_name: str, *, language: str = 'python',
function_format: str = None) -> str:
"""create a function from query_respond."""
if language not in ('python'):
raise NotImplementedError(f"Not support to parse language {language}!")
# set function form
if function_format is None:
assert language == 'python', f"Expect python language for default function_format, but got {language}."
function_format = """def {function_name}():\n{code}"""
# Extract the code module in the open-interpreter respond message.
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and 'language' in item["function_call"]['parsed_arguments']
and item["function_call"]['parsed_arguments']['language'] == language]
# add indent.
indented_code_str = textwrap.indent("\n".join(code), ' ' * 4)
# Return the code after deduplication.
if language == "python":
return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))
def gen_query(func: Callable, args, kwargs) -> str:
# Get the annotation of the function as part of the query.
desc = func.__doc__
signature = inspect.signature(func)
# Get the signature of the wrapped function and the assignment of the input parameters as part of the query.
bound_args = signature.bind(*args, **kwargs)
bound_args.apply_defaults()
query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..."
return query
def gen_template_fun(func: Callable) -> str:
return f"def {func.__name__}{str(inspect.signature(func))}\n # here is your code ..."
class OpenInterpreterDecorator(object):
def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None:
self.save_code = save_code
self.code_file_path = code_file_path
self.clear_code = clear_code
def __call__(self, wrapped):
@wrapt.decorator
async def wrapper(wrapped: Callable, instance, args, kwargs):
# Get the decorated function name.
func_name = wrapped.__name__
# If the script exists locally and clearcode is not required, execute the function from the script.
if Path(self.code_file_path).is_file() and not self.clear_code:
return run_function_script(self.code_file_path, func_name, *args, **kwargs)
# Auto run generate code by using open-interpreter.
interpreter = OpenCodeInterpreter()
query = gen_query(wrapped, args, kwargs)
logger.info(f"query for OpenCodeInterpreter: \n {query}")
respond = interpreter.chat(query)
# Assemble the code blocks generated by open-interpreter into a function without parameters.
func_code = interpreter.extract_function(respond, func_name)
# Clone the `func_code` into wrapped, that is,
# keep the `func_code` and wrapped functions with the same input parameter and return value types.
template_func = gen_template_fun(wrapped)
cf = CloneFunction()
code = await cf.run(template_func=template_func, source_code=func_code)
# Display the generated function in the terminal.
logger_code = highlight(code, "python")
logger.info(f"Creating following Python function:\n{logger_code}")
# execute this function.
try:
res = run_function_code(code, func_name, *args, **kwargs)
if self.save_code:
cf._save(self.code_file_path, code)
except Exception as e:
raise Exception("Could not evaluate Python code", e)
return res
return wrapper(wrapped)

View file

@ -0,0 +1,25 @@
# 添加代码语法高亮显示
from pygments import highlight as highlight_
from pygments.lexers import PythonLexer, SqlLexer
from pygments.formatters import TerminalFormatter, HtmlFormatter
def highlight(code: str, language: str = 'python', formatter: str = 'terminal'):
# 指定要高亮的语言
if language.lower() == 'python':
lexer = PythonLexer()
elif language.lower() == 'sql':
lexer = SqlLexer()
else:
raise ValueError(f"Unsupported language: {language}")
# 指定输出格式
if formatter.lower() == 'terminal':
formatter = TerminalFormatter()
elif formatter.lower() == 'html':
formatter = HtmlFormatter()
else:
raise ValueError(f"Unsupported formatter: {formatter}")
# 使用 Pygments 高亮代码片段
return highlight_(code, lexer, formatter)

View file

@ -38,4 +38,5 @@ typing-inspect==0.8.0
typing_extensions==4.5.0
libcst==1.0.1
qdrant-client==1.4.0
semantic-kernel==0.3.10.dev0
open-interpreter==0.1.3
ta==0.10.2

View file

@ -0,0 +1,54 @@
import pytest
from metagpt.actions.clone_function import CloneFunction, run_function_code
source_code = """
import pandas as pd
import ta
def user_indicator():
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
"""
template_code = """
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame:
import pandas as pd
# here is your code.
"""
def get_expected_res():
import pandas as pd
import ta
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
return stock_data
@pytest.mark.asyncio
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert 'def ' in code
stock_path = './tests/data/baba_stock.csv'
df, msg = run_function_code(code, 'stock_indicator', stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)

View file

@ -0,0 +1,42 @@
import pytest
import pandas as pd
from pathlib import Path
from tests.data import sales_desc, store_desc
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
from metagpt.actions import Action
from metagpt.logs import logger
logger.add('./tests/data/test_ci.log')
stock = "./tests/data/baba_stock.csv"
# TODO: 需要一种表格数据格式能够支持schame管理的标注字段类型和字段含义。
class CreateStockIndicators(Action):
@OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
async def run(self, stock_path: str, indicators=['Simple Moving Average', 'BollingerBands']) -> pd.DataFrame:
"""对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
指标生成对应的三列: SMA, BB_upper, BB_lower
"""
...
@pytest.mark.asyncio
async def test_actions():
# 计算指标
indicators = ['Simple Moving Average', 'BollingerBands']
stocker = CreateStockIndicators()
df, msg = await stocker.run(stock, indicators=indicators)
assert isinstance(df, pd.DataFrame)
assert 'Close' in df.columns
assert 'Date' in df.columns
# 将df保存为文件将文件路径传入到下一个action
df_path = './tests/data/stock_indicators.csv'
df.to_csv(df_path)
assert Path(df_path).is_file()
# 可视化指标结果
figure_path = './tests/data/figure_ci.png'
ci_ploter = OpenCodeInterpreter()
ci_ploter.chat(f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper布林带上界, BB_lower布林带下界进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算把Date列转换为日期类型。要求图片优美BB_upper, BB_lower之间使用合适的颜色填充。")
assert Path(figure_path).is_file()