diff --git a/metagpt/actions/clone_function.py b/metagpt/actions/clone_function.py deleted file mode 100644 index 7053df97b..000000000 --- a/metagpt/actions/clone_function.py +++ /dev/null @@ -1,67 +0,0 @@ -from pathlib import Path - -from pydantic import Field - -from metagpt.actions.write_code import WriteCode -from metagpt.llm import LLM -from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM -from metagpt.schema import Message -from metagpt.utils.exceptions import handle_exception -from metagpt.utils.highlight import highlight - -CLONE_PROMPT = """ -*context* -Please convert the function code ```{source_code}``` into the the function format: ```{template_func}```. -*Please Write code based on the following list and context* -1. Write code start with ```, and end with ```. -2. Please implement it in one function if possible, except for import statements. for exmaple: -```python -import pandas as pd -def run(*args) -> pd.DataFrame: - ... -``` -3. Do not use public member functions that do not exist in your design. -4. The output function name, input parameters and return value must be the same as ```{template_func}```. -5. Make sure the results before and after the code conversion are required to be exactly the same. -6. Don't repeat my context in your replies. -7. Return full results, for example, if the return value has df.head(), please return df. -8. If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ... -""" - - -class CloneFunction(WriteCode): - name: str = "CloneFunction" - context: list[Message] = [] - llm: BaseLLM = Field(default_factory=LLM) - - def _save(self, code_path, code): - if isinstance(code_path, str): - code_path = Path(code_path) - code_path.parent.mkdir(parents=True, exist_ok=True) - code_path.write_text(code, encoding="utf-8") - logger.info(f"Saving Code to {code_path}") - - async def run(self, template_func: str, source_code: str) -> str: - """将source_code转换成template_func一样的入参和返回类型""" - prompt = CLONE_PROMPT.format(source_code=source_code, template_func=template_func) - logger.info(f"query for CloneFunction: \n {prompt}") - code = await self.write_code(prompt) - logger.info(f"CloneFunction code is \n {highlight(code)}") - return code - - -@handle_exception -def run_function_code(func_code: str, func_name: str, *args, **kwargs): - """Run function code from string code.""" - locals_ = {} - exec(func_code, locals_) - func = locals_[func_name] - return func(*args, **kwargs), "" - - -def run_function_script(code_script_path: str, func_name: str, *args, **kwargs): - """Run function code from script.""" - code_path = Path(code_script_path) - code = code_path.read_text(encoding="utf-8") - return run_function_code(code, func_name, *args, **kwargs) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index e4b066a0c..df8c330b8 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -49,6 +49,9 @@ class ZhiPuAIGPTAPI(BaseLLM): zhipuai.api_key = config.zhipuai_api_key # due to use openai sdk, set the api_key but it will't be used. # openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used. + if config.openai_proxy: + # FIXME: openai v1.x sdk has no proxy support + openai.proxy = config.openai_proxy def _const_kwargs(self, messages: list[dict]) -> dict: kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3} diff --git a/metagpt/roles/tutorial_assistant.py b/metagpt/roles/tutorial_assistant.py index 5d1323371..bedf8b3be 100644 --- a/metagpt/roles/tutorial_assistant.py +++ b/metagpt/roles/tutorial_assistant.py @@ -90,4 +90,5 @@ class TutorialAssistant(Role): msg = await super().react() root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8")) + msg.content = str(root_path / f"{self.main_title}.md") return msg diff --git a/metagpt/tools/code_interpreter.py b/metagpt/tools/code_interpreter.py deleted file mode 100644 index 5592b0704..000000000 --- a/metagpt/tools/code_interpreter.py +++ /dev/null @@ -1,207 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : -@Author : -@File : code_interpreter.py -@Warning : open-interpreter 0.1.17 requires openai<0.29.0,>=0.28.0, but you have openai 1.6.0 which is incompatible. - open-interpreter 0.1.17 requires tiktoken<0.5.0,>=0.4.0, but you have tiktoken 0.5.2 which is incompatible. -""" - -# import inspect -# import re -# import textwrap -# from pathlib import Path -# from typing import Callable, Dict, List -# -# import wrapt -# from interpreter.core.core import Interpreter -# -# from metagpt.actions.clone_function import ( -# CloneFunction, -# run_function_code, -# run_function_script, -# ) -# from metagpt.config import CONFIG -# from metagpt.logs import logger -# from metagpt.utils.highlight import highlight -# -# -# def extract_python_code(code: str): -# """Extract code blocks: If the code comments are the same, only the last code block is kept.""" -# # Use regular expressions to match comment blocks and related code. -# pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)" -# matches = re.findall(pattern, code, re.DOTALL) -# -# # Extract the last code block when encountering the same comment. -# unique_comments = {} -# for comment, code_block in matches: -# unique_comments[comment] = code_block -# -# # concatenate into functional form -# result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()]) -# header_code = code[: code.find("#")] -# code = header_code + result_code -# -# logger.info(f"Extract python code: \n {highlight(code)}") -# -# return code -# -# -# class OpenCodeInterpreter(object): -# """https://github.com/KillianLucas/open-interpreter""" -# -# def __init__(self, auto_run: bool = True) -> None: -# interpreter = Interpreter() -# interpreter.auto_run = auto_run -# interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo" -# interpreter.api_key = CONFIG.openai_api_key -# self.interpreter = interpreter -# -# def chat(self, query: str, reset: bool = True): -# if reset: -# self.interpreter.reset() -# return self.interpreter.chat(query) -# -# @staticmethod -# def extract_function( -# query_respond: List, function_name: str, *, language: str = "python", function_format: str = None -# ) -> str: -# """create a function from query_respond.""" -# if language not in ("python"): -# raise NotImplementedError(f"Not support to parse language {language}!") -# -# # set function form -# if function_format is None: -# assert language == "python", f"Expect python language for default function_format, but got {language}." -# function_format = """def {function_name}():\n{code}""" -# # Extract the code module in the open-interpreter respond message. -# # The query_respond of open-interpreter before v0.1.4 is: -# # [{'role': 'user', 'content': your query string}, -# # {'role': 'assistant', 'content': plan from llm, 'function_call': { -# # "name": "run_code", "arguments": "{"language": "python", "code": code of first plan}, -# # "parsed_arguments": {"language": "python", "code": code of first plan} -# # ...] -# if "function_call" in query_respond[1]: -# code = [ -# item["function_call"]["parsed_arguments"]["code"] -# for item in query_respond -# if "function_call" in item -# and "parsed_arguments" in item["function_call"] -# and "language" in item["function_call"]["parsed_arguments"] -# and item["function_call"]["parsed_arguments"]["language"] == language -# ] -# # The query_respond of open-interpreter v0.1.7 is: -# # [{'role': 'user', 'message': your query string}, -# # {'role': 'assistant', 'message': plan from llm, 'language': 'python', -# # 'code': code of first plan, 'output': output of first plan code}, -# # ...] -# elif "code" in query_respond[1]: -# code = [ -# item["code"] -# for item in query_respond -# if "code" in item and "language" in item and item["language"] == language -# ] -# else: -# raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}") -# # add indent. -# indented_code_str = textwrap.indent("\n".join(code), " " * 4) -# # Return the code after deduplication. -# if language == "python": -# return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str)) -# -# -# def gen_query(func: Callable, args, kwargs) -> str: -# # Get the annotation of the function as part of the query. -# desc = func.__doc__ -# signature = inspect.signature(func) -# # Get the signature of the wrapped function and the assignment of the input parameters as part of the query. -# bound_args = signature.bind(*args, **kwargs) -# bound_args.apply_defaults() -# query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..." -# return query -# -# -# def gen_template_fun(func: Callable) -> str: -# return f"def {func.__name__}{str(inspect.signature(func))}\n # here is your code ..." -# -# -# class OpenInterpreterDecorator(object): -# def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None: -# self.save_code = save_code -# self.code_file_path = code_file_path -# self.clear_code = clear_code -# -# def _have_code(self, rsp: List[Dict]): -# # Is there any code generated? -# return "code" in rsp[1] and rsp[1]["code"] not in ("", None) -# -# def _is_faild_plan(self, rsp: List[Dict]): -# # is faild plan? -# func_code = OpenCodeInterpreter.extract_function(rsp, "function") -# # If there is no more than 1 '\n', the plan execution fails. -# if isinstance(func_code, str) and func_code.count("\n") <= 1: -# return True -# return False -# -# def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3): -# for _ in range(max_try): -# # TODO: If no code or faild plan is generated, execute chat again, repeating no more than max_try times. -# if self._have_code(respond) and not self._is_faild_plan(respond): -# break -# elif not self._have_code(respond): -# logger.warning(f"llm did not return executable code, resend the query: \n{query}") -# respond = interpreter.chat(query) -# elif self._is_faild_plan(respond): -# logger.warning(f"llm did not generate successful plan, resend the query: \n{query}") -# respond = interpreter.chat(query) -# -# # Post-processing of respond -# if not self._have_code(respond): -# error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}" -# logger.error(error_msg) -# raise ValueError(error_msg) -# -# if self._is_faild_plan(respond): -# error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}" -# logger.error(error_msg) -# raise ValueError(error_msg) -# return respond -# -# def __call__(self, wrapped): -# @wrapt.decorator -# async def wrapper(wrapped: Callable, instance, args, kwargs): -# # Get the decorated function name. -# func_name = wrapped.__name__ -# # If the script exists locally and clearcode is not required, execute the function from the script. -# if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code: -# return run_function_script(self.code_file_path, func_name, *args, **kwargs) -# -# # Auto run generate code by using open-interpreter. -# interpreter = OpenCodeInterpreter() -# query = gen_query(wrapped, args, kwargs) -# logger.info(f"query for OpenCodeInterpreter: \n {query}") -# respond = interpreter.chat(query) -# # Make sure the response is as expected. -# respond = self._check_respond(query, interpreter, respond, 3) -# # Assemble the code blocks generated by open-interpreter into a function without parameters. -# func_code = interpreter.extract_function(respond, func_name) -# # Clone the `func_code` into wrapped, that is, -# # keep the `func_code` and wrapped functions with the same input parameter and return value types. -# template_func = gen_template_fun(wrapped) -# cf = CloneFunction() -# code = await cf.run(template_func=template_func, source_code=func_code) -# # Display the generated function in the terminal. -# logger_code = highlight(code, "python") -# logger.info(f"Creating following Python function:\n{logger_code}") -# # execute this function. -# try: -# res = run_function_code(code, func_name, *args, **kwargs) -# if self.save_code and self.code_file_path: -# cf._save(self.code_file_path, code) -# except Exception as e: -# logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}") -# raise Exception("Could not evaluate Python code", e) -# return res -# -# return wrapper(wrapped) diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index ddadda7e6..b3b93cf9f 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -1,58 +1,61 @@ -# #!/usr/bin/env python3 -# # _*_ coding: utf-8 _*_ -# -# """ -# @Time : 2023/10/09 18:40:34 -# @Author : Stitch-z -# @File : test_invoice_ocr.py -# """ -# -# import os -# from pathlib import Path -# -# import pytest -# -# from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion -# -# -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# "invoice_path", -# [ -# "../../data/invoices/invoice-3.jpg", -# "../../data/invoices/invoice-4.zip", -# ], -# ) -# async def test_invoice_ocr(invoice_path: str): -# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) -# filename = os.path.basename(invoice_path) -# resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) -# assert isinstance(resp, list) -# -# -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# ("invoice_path", "expected_result"), -# [ -# ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), -# ], -# ) -# async def test_generate_table(invoice_path: str, expected_result: list[dict]): -# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) -# filename = os.path.basename(invoice_path) -# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) -# table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) -# assert table_data == expected_result -# -# -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# ("invoice_path", "query", "expected_result"), -# [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], -# ) -# async def test_reply_question(invoice_path: str, query: dict, expected_result: str): -# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) -# filename = os.path.basename(invoice_path) -# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) -# result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) -# assert expected_result in result +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/10/09 18:40:34 +@Author : Stitch-z +@File : test_invoice_ocr.py +""" +import json +import os +from pathlib import Path + +import pytest + +from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "invoice_path", + [ + "../../data/invoices/invoice-3.jpg", + "../../data/invoices/invoice-4.zip", + ], +) +async def test_invoice_ocr(invoice_path: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + assert isinstance(resp, list) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "expected_result"), + [ + ( + "../../data/invoices/invoice-1.pdf", + [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}] + ), + ], +) +async def test_generate_table(invoice_path: str, expected_result: list[dict]): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) + assert json.dumps(table_data) == json.dumps(expected_result) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "query", "expected_result"), + [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], +) +async def test_reply_question(invoice_path: str, query: dict, expected_result: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) + assert expected_result in result diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index d9cd23281..06f2cba62 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -51,3 +51,13 @@ async def test_zhipuai_acompletion(mocker): resp = await zhipu_gpt.aask(prompt_msg) assert resp == resp_content + + +def test_zhipuai_proxy(mocker): + import openai + + from metagpt.config import CONFIG + + CONFIG.openai_proxy = "http://127.0.0.1:8080" + _ = ZhiPuAIGPTAPI() + assert openai.proxy == CONFIG.openai_proxy diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index e90182dde..48abb9eb8 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -1,63 +1,59 @@ -# #!/usr/bin/env python3 -# # _*_ coding: utf-8 _*_ -# -# """ -# @Time : 2023/9/21 23:11:27 -# @Author : Stitch-z -# @File : test_invoice_ocr_assistant.py -# """ -# -# import json -# from pathlib import Path -# -# import pandas as pd -# import pytest -# -# from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath -# from metagpt.schema import Message -# -# -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# ("query", "invoice_path", "invoice_table_path", "expected_result"), -# [ -# ( -# "Invoicing date", -# Path("../../data/invoices/invoice-1.pdf"), -# Path("../../../data/invoice_table/invoice-1.xlsx"), -# [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], -# ), -# ( -# "Invoicing date", -# Path("../../data/invoices/invoice-2.png"), -# Path("../../../data/invoice_table/invoice-2.xlsx"), -# [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], -# ), -# ( -# "Invoicing date", -# Path("../../data/invoices/invoice-3.jpg"), -# Path("../../../data/invoice_table/invoice-3.xlsx"), -# [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], -# ), -# ( -# "Invoicing date", -# Path("../../data/invoices/invoice-4.zip"), -# Path("../../../data/invoice_table/invoice-4.xlsx"), -# [ -# {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, -# {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, -# {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, -# ], -# ), -# ], -# ) -# async def test_invoice_ocr_assistant( -# query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] -# ): -# invoice_path = Path.cwd() / invoice_path -# role = InvoiceOCRAssistant() -# await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) -# invoice_table_path = Path.cwd() / invoice_table_path -# df = pd.read_excel(invoice_table_path) -# dict_result = df.to_dict(orient="records") -# assert json.dumps(dict_result) == json.dumps(expected_result) +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 23:11:27 +@Author : Stitch-z +@File : test_invoice_ocr_assistant.py +""" + +from pathlib import Path + +import pandas as pd +import pytest + +from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath +from metagpt.schema import Message + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "invoice_path", "invoice_table_path", "expected_result"), + [ + ( + "Invoicing date", + Path("../../data/invoices/invoice-1.pdf"), + Path("../../../data/invoice_table/invoice-1.xlsx"), + {"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-2.png"), + Path("../../../data/invoice_table/invoice-2.xlsx"), + {"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-3.jpg"), + Path("../../../data/invoice_table/invoice-3.xlsx"), + {"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, + ) + ], +) +async def test_invoice_ocr_assistant( + query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict +): + invoice_path = Path.cwd() / invoice_path + role = InvoiceOCRAssistant() + await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) + invoice_table_path = Path.cwd() / invoice_table_path + df = pd.read_excel(invoice_table_path) + resp = df.to_dict(orient="records") + assert isinstance(resp, list) + assert len(resp) == 1 + resp = resp[0] + assert expected_result["收款人"] == resp["收款人"] + assert expected_result["城市"] in resp["城市"] + assert int(expected_result["总费用/元"]) == int(resp["总费用/元"]) + assert expected_result["开票日期"] == resp["开票日期"] + diff --git a/tests/metagpt/roles/test_tutorial_assistant.py b/tests/metagpt/roles/test_tutorial_assistant.py index 3158a5fc1..4455e1bf6 100644 --- a/tests/metagpt/roles/test_tutorial_assistant.py +++ b/tests/metagpt/roles/test_tutorial_assistant.py @@ -6,7 +6,7 @@ @File : test_tutorial_assistant.py """ import shutil - +import aiofiles import pytest from metagpt.const import TUTORIAL_PATH @@ -14,20 +14,17 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant @pytest.mark.asyncio -@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about Python")]) +@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")]) async def test_tutorial_assistant(language: str, topic: str): shutil.rmtree(path=TUTORIAL_PATH, ignore_errors=True) - topic = "Write a tutorial about MySQL" role = TutorialAssistant(language=language) msg = await role.run(topic) - assert "MySQL" in msg.content assert TUTORIAL_PATH.exists() - # filename = msg.content - # title = filename.split("/")[-1].split(".")[0] - # async with aiofiles.open(filename, mode="r") as reader: - # content = await reader.read() - # assert content.startswith(f"# {title}") + filename = msg.content + async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader: + content = await reader.read() + assert "pip" in content if __name__ == "__main__": diff --git a/tests/metagpt/tools/test_code_interpreter.py b/tests/metagpt/tools/test_code_interpreter.py deleted file mode 100644 index 71df6315b..000000000 --- a/tests/metagpt/tools/test_code_interpreter.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : -@Author : -@File : test_code_interpreter.py -@Warning : open-interpreter 0.1.17 requires openai<0.29.0,>=0.28.0, but you have openai 1.6.0 which is incompatible. - open-interpreter 0.1.17 requires tiktoken<0.5.0,>=0.4.0, but you have tiktoken 0.5.2 which is incompatible. -""" - -# from pathlib import Path -# -# import pandas as pd -# import pytest -# -# from metagpt.actions import Action -# from metagpt.logs import logger -# from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator -# -# logger.add("./tests/data/test_ci.log") -# stock = "./tests/data/baba_stock.csv" -# -# -# # TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。 -# class CreateStockIndicators(Action): -# @OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py") -# async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame: -# """对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包; -# 指标生成对应的三列: SMA, BB_upper, BB_lower -# """ -# ... -# -# -# @pytest.mark.asyncio -# async def test_actions(): -# # 计算指标 -# indicators = ["Simple Moving Average", "BollingerBands"] -# stocker = CreateStockIndicators() -# df, msg = await stocker.run(stock, indicators=indicators) -# assert isinstance(df, pd.DataFrame) -# assert "Close" in df.columns -# assert "Date" in df.columns -# # 将df保存为文件,将文件路径传入到下一个action -# df_path = "./tests/data/stock_indicators.csv" -# df.to_csv(df_path) -# assert Path(df_path).is_file() -# # 可视化指标结果 -# figure_path = "./tests/data/figure_ci.png" -# ci_ploter = OpenCodeInterpreter() -# ci_ploter.chat( -# f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。" -# ) -# assert Path(figure_path).is_file()