diff --git a/metagpt/actions/clone_function.py b/metagpt/actions/clone_function.py new file mode 100644 index 000000000..377c35de4 --- /dev/null +++ b/metagpt/actions/clone_function.py @@ -0,0 +1,66 @@ +from pathlib import Path +import traceback + +from metagpt.actions.write_code import WriteCode +from metagpt.logs import logger +from metagpt.schema import Message +from metagpt.utils.highlight import highlight + +CLONE_PROMPT = """ +*context* +Please convert the function code ```{source_code}``` into the the function format: ```{template_func}```. +*Please Write code based on the following list and context* +1. Write code start with ```, and end with ```. +2. Please implement it in one function if possible, except for import statements. for exmaple: +```python +import pandas as pd +def run(*args) -> pd.DataFrame: + ... +``` +3. Do not use public member functions that do not exist in your design. +4. The output function name, input parameters and return value must be the same as ```{template_func}```. +5. Make sure the results before and after the code conversion are required to be exactly the same. +6. Don't repeat my context in your replies. +7. Return full results, for example, if the return value has df.head(), please return df. +8. If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ... 
def run_fucntion_code(func_code: str, func_name: str, *args, **kwargs):
    """Execute generated source code and call one function it defines.

    Args:
        func_code: Python source expected to define ``func_name`` at top level.
        func_name: Name of the function to look up after executing the source.
        *args: Positional arguments forwarded to the function.
        **kwargs: Keyword arguments forwarded to the function.

    Returns:
        A ``(result, error)`` tuple. On success ``error`` is ``""``. If
        executing the source, looking up the function, or calling it raises
        an ``Exception``, ``result`` is ``""`` and ``error`` is the formatted
        traceback; nothing is re-raised.
    """
    try:
        # SECURITY: exec() runs whatever the source contains with the full
        # privileges of this process. Only feed it code you are prepared to
        # run (here: LLM-generated code), ideally inside a sandbox.
        namespace = {}
        exec(func_code, namespace)
        func = namespace[func_name]
        return func(*args, **kwargs), ""
    except Exception:
        return "", traceback.format_exc()


def run_fucntion_script(code_script_path: "str | Path", func_name: str, *args, **kwargs):
    """Load a Python script from disk and run one function defined in it.

    Args:
        code_script_path: Location of the script, as a string or a ``Path``.
        func_name: Name of the function inside the script to call.
        *args: Positional arguments forwarded to the function.
        **kwargs: Keyword arguments forwarded to the function.

    Returns:
        The ``(result, error)`` tuple produced by ``run_fucntion_code``.

    Raises:
        OSError: If the script cannot be read (e.g. it does not exist).
    """
    # Path() accepts both str and Path, so normalise unconditionally; binding
    # the path only for str inputs left it undefined for Path arguments.
    code_path = Path(code_script_path)
    code = code_path.read_text(encoding='utf-8')
    return run_fucntion_code(code, func_name, *args, **kwargs)
def extract_python_code(code: str):
    """Collapse repeated comment-delimited sections of ``code``.

    The text is split into sections, each starting at a ``# ...`` comment and
    running up to the next comment (or the end of the text). When several
    sections share the same comment text, only the last one is kept, placed
    where that comment first appeared. Everything before the first section is
    kept as a header.

    Args:
        code: Source text to de-duplicate.

    Returns:
        The de-duplicated text, or ``code`` unchanged when it contains no
        comment section.
    """
    # A section is a "# comment" line followed by everything up to the next
    # comment line or the end of the text.
    pattern = r'(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)'
    sections = list(re.finditer(pattern, code, re.DOTALL))

    if sections:
        # Later sections overwrite earlier ones with the same comment; the dict
        # keeps the position of the first occurrence.
        latest_by_comment = {}
        for section in sections:
            comment, code_block = section.groups()
            latest_by_comment[comment] = code_block

        result_code = '\n'.join(
            f"{comment}\n{code_block}" for comment, code_block in latest_by_comment.items()
        )
        # Cut the header at the first matched section. Using str.find("#")
        # here would yield -1 (and drop the last character) when there is no
        # comment, and would cut too early at a "#" that is not a section start.
        code = code[:sections[0].start()] + result_code

    logger.info(f"Extract python code: \n {highlight(code)}")

    return code


class OpenCodeInterpreter(object):
    """Thin wrapper around open-interpreter.

    See https://github.com/KillianLucas/open-interpreter
    """

    def __init__(self, auto_run: bool = True) -> None:
        """Create and configure the underlying ``Interpreter``.

        Args:
            auto_run: Stored on the interpreter as its ``auto_run`` flag.
        """
        interpreter = Interpreter()
        interpreter.auto_run = auto_run
        # Fall back to a default model name when none is configured.
        interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
        interpreter.api_key = CONFIG.openai_api_key
        self.interpreter = interpreter

    def chat(self, query: str, reset: bool = True):
        """Send ``query`` to the interpreter and return its messages.

        Args:
            query: The request passed to the interpreter.
            reset: When true, ``reset()`` is called on the interpreter first.

        Returns:
            Whatever ``Interpreter.chat(query, return_messages=True)`` returns.
        """
        if reset:
            self.interpreter.reset()
        return self.interpreter.chat(query, return_messages=True)

    @staticmethod
    def extract_function(query_respond: List, function_name: str, *, language: str = 'python',
                         function_format: str = None) -> str:
        """Assemble the code snippets of a chat response into one function.

        Args:
            query_respond: Messages returned by ``chat``. Entries that carry
                ``item['function_call']['parsed_arguments']`` with a matching
                ``language`` contribute their ``code``.
            function_name: Name given to the generated function.
            language: Language of the snippets to collect; only ``'python'``
                is supported.
            function_format: Template with ``{function_name}`` and ``{code}``
                placeholders. Defaults to a parameterless ``def``.

        Returns:
            The function source with duplicated comment sections removed.

        Raises:
            NotImplementedError: If ``language`` is not supported.
        """
        # Real tuple membership: ``in ('python')`` is a substring test on a
        # plain string and would accept values such as 'py'.
        if language not in ('python',):
            raise NotImplementedError(f"Not support to parse language {language}!")

        # Default wrapper: a function without parameters.
        if function_format is None:
            function_format = """def {function_name}():\n{code}"""
        # Collect the code snippets from the open-interpreter response messages.
        code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
                if "function_call" in item
                and "parsed_arguments" in item["function_call"]
                and 'language' in item["function_call"]['parsed_arguments']
                and item["function_call"]['parsed_arguments']['language'] == language]
        # Indent the snippets so they form the function body.
        indented_code_str = textwrap.indent("\n".join(code), ' ' * 4)
        # De-duplicate sections that share the same comment.
        return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))


def gen_query(func: Callable, args, kwargs) -> str:
    """Build the interpreter query for a call to ``func``.

    Args:
        func: Function whose docstring is the main body of the query.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        The docstring, the bound arguments (defaults applied), and a hint
        about third-party packages, joined into one string.

    Raises:
        TypeError: If ``args``/``kwargs`` do not match the signature of ``func``.
    """
    # The docstring is the main body of the query.
    desc = func.__doc__
    signature = inspect.signature(func)
    # Bind the call arguments to parameter names and fill in defaults.
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..."
    return query


def gen_template_fun(func: Callable) -> str:
    """Return a function template with the name and signature of ``func``.

    The header ends with ``:`` so the template is a well-formed function
    header, followed by a placeholder comment for the body.
    """
    return f"def {func.__name__}{inspect.signature(func)}:\n    # here is your code ..."
class OpenInterpreterDecorator(object):
    """Decorator that implements a function by generating its code.

    The decorated function's docstring and call arguments are turned into a
    query for open-interpreter. The produced code is converted to the
    decorated function's signature by ``CloneFunction`` and then executed.
    The generated code can be saved to a script and reused on later calls.
    """

    def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None:
        """Store the caching options.

        Args:
            save_code: Save generated code to ``code_file_path`` after it ran
                without error.
            code_file_path: Script used to load/save the generated code.
                ``None`` disables both loading and saving.
            clear_code: Ignore an existing script and generate code again.
        """
        self.save_code = save_code
        self.code_file_path = code_file_path
        self.clear_code = clear_code

    def __call__(self, wrapped):
        """Wrap ``wrapped`` so that calling it runs generated code.

        The wrapper is a coroutine. It returns the ``(result, error)`` tuple
        of ``run_fucntion_script`` / ``run_fucntion_code``.
        """
        @wrapt.decorator
        async def wrapper(wrapped: Callable, instance, args, kwargs):
            # Name of the decorated function.
            func_name = wrapped.__name__
            # Reuse the saved script when one exists and regeneration was not
            # requested. Guard against the default ``code_file_path=None``,
            # which Path() rejects with TypeError.
            if (
                not self.clear_code
                and self.code_file_path is not None
                and Path(self.code_file_path).is_file()
            ):
                return run_fucntion_script(self.code_file_path, func_name, *args, **kwargs)

            # Let open-interpreter generate the code step by step.
            interpreter = OpenCodeInterpreter()
            query = gen_query(wrapped, args, kwargs)
            logger.info(f"query for OpenCodeInterpreter: \n {query}")
            respond = interpreter.chat(query)
            # Assemble the generated snippets into a parameterless function.
            func_code = interpreter.extract_function(respond, func_name)
            # Clone it into the shape of ``wrapped`` (same parameters and
            # return type).
            template_func = gen_template_fun(wrapped)
            cf = CloneFunction()
            code = await cf.run(template_func=template_func, source_code=func_code)
            # Show the generated function in the terminal.
            logger_code = highlight(code, "python")
            logger.info(f"Creating following Python function:\n{logger_code}")
            # Run it. run_fucntion_code reports failures through the second
            # tuple element instead of raising, so inspect that value rather
            # than relying on an exception handler.
            res = run_fucntion_code(code, func_name, *args, **kwargs)
            _, error = res
            # Save only code that ran cleanly; a saved failing script would be
            # replayed on every later call.
            if self.save_code and self.code_file_path is not None and not error:
                cf._save(self.code_file_path, code)
            return res
        return wrapper(wrapped)
elif language.lower() == 'sql': + lexer = SqlLexer() + else: + raise ValueError(f"Unsupported language: {language}") + + # 指定输出格式 + if formatter.lower() == 'terminal': + formatter = TerminalFormatter() + elif formatter.lower() == 'html': + formatter = HtmlFormatter() + else: + raise ValueError(f"Unsupported formatter: {formatter}") + + # 使用 Pygments 高亮代码片段 + return highlight_(code, lexer, formatter) diff --git a/tests/metagpt/actions/test_clone_function.py b/tests/metagpt/actions/test_clone_function.py new file mode 100644 index 000000000..7ac58e065 --- /dev/null +++ b/tests/metagpt/actions/test_clone_function.py @@ -0,0 +1,54 @@ +import pytest + +from metagpt.actions.clone_function import CloneFunction, run_fucntion_code + + +source_code = """ +def user_indicator(): + import pandas as pd + import ta + + # 读取股票数据 + stock_data = pd.read_csv('./tests/data/baba_stock.csv') + stock_data.head() + # 计算简单移动平均线 + stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6) + stock_data[['Date', 'Close', 'SMA']].head() + # 计算布林带 + stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20) + stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head() +""" + +template_code = """ +def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame: + import pandas as pd + # here is your code. 
def get_expected_res():
    """Compute the reference DataFrame the cloned function must reproduce."""
    import pandas as pd
    import ta

    # Load the stock data.
    frame = pd.read_csv('./tests/data/baba_stock.csv')
    frame.head()
    # Simple moving average.
    frame['SMA'] = ta.trend.sma_indicator(frame['Close'], window=6)
    frame[['Date', 'Close', 'SMA']].head()
    # Bollinger Bands: evaluate all three indicators first, then assign them.
    upper = ta.volatility.bollinger_hband_indicator(frame['Close'], window=20)
    middle = ta.volatility.bollinger_mavg(frame['Close'], window=20)
    lower = ta.volatility.bollinger_lband_indicator(frame['Close'], window=20)
    frame['bb_upper'] = upper
    frame['bb_middle'] = middle
    frame['bb_lower'] = lower
    frame[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
    return frame


@pytest.mark.asyncio
async def test_clone_function():
    """The cloned function runs without error and matches the reference."""
    cloner = CloneFunction()
    generated = await cloner.run(template_code, source_code)
    assert 'def ' in generated
    csv_path = './tests/data/baba_stock.csv'
    result, error = run_fucntion_code(generated, 'stock_indicator', csv_path)
    assert not error
    reference = get_expected_res()
    assert result.equals(reference)
@pytest.mark.asyncio
async def test_actions():
    """Generate stock indicators, save them, then have them plotted."""
    # Compute the indicators.
    wanted = ['Simple Moving Average', 'BollingerBands']
    action = CreateStockIndicators()
    df, _msg = await action.run(stock, indicators=wanted)
    assert isinstance(df, pd.DataFrame)
    for column in ('Close', 'Date'):
        assert column in df.columns
    # Save the frame to a file; its path is handed to the next action.
    df_path = './tests/data/stock_indicators.csv'
    df.to_csv(df_path)
    assert Path(df_path).is_file()
    # Visualise the indicator results.
    figure_path = './tests/data/figure_ci.png'
    plot_request = f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。"
    plotter = OpenCodeInterpreter()
    plotter.chat(plot_request)
    assert Path(figure_path).is_file()