diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index 7f16aa9a4..ddadda7e6 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -1,58 +1,58 @@ -#!/usr/bin/env python3 -# _*_ coding: utf-8 _*_ - -""" -@Time : 2023/10/09 18:40:34 -@Author : Stitch-z -@File : test_invoice_ocr.py -""" - -import os -from pathlib import Path - -import pytest - -from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "invoice_path", - [ - "../../data/invoices/invoice-3.jpg", - "../../data/invoices/invoice-4.zip", - ], -) -async def test_invoice_ocr(invoice_path: str): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) - filename = os.path.basename(invoice_path) - resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) - assert isinstance(resp, list) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("invoice_path", "expected_result"), - [ - ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), - ], -) -async def test_generate_table(invoice_path: str, expected_result: list[dict]): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) - filename = os.path.basename(invoice_path) - ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) - table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) - assert table_data == expected_result - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("invoice_path", "query", "expected_result"), - [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], -) -async def test_reply_question(invoice_path: str, query: dict, expected_result: str): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) - filename = os.path.basename(invoice_path) - ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) - result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) - assert expected_result in result +# #!/usr/bin/env python3 +# # _*_ coding: utf-8 _*_ +# +# """ +# @Time : 2023/10/09 18:40:34 +# @Author : Stitch-z +# @File : test_invoice_ocr.py +# """ +# +# import os +# from pathlib import Path +# +# import pytest +# +# from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "invoice_path", +# [ +# "../../data/invoices/invoice-3.jpg", +# "../../data/invoices/invoice-4.zip", +# ], +# ) +# async def test_invoice_ocr(invoice_path: str): +# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) +# filename = os.path.basename(invoice_path) +# resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) +# assert isinstance(resp, list) +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# ("invoice_path", "expected_result"), +# [ +# ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), +# ], +# ) +# async def test_generate_table(invoice_path: str, expected_result: list[dict]): +# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) +# filename = os.path.basename(invoice_path) +# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) +# table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) +# assert table_data == expected_result +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# ("invoice_path", "query", "expected_result"), +# [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], +# ) +# async def test_reply_question(invoice_path: str, query: dict, expected_result: str): +# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) +# filename = os.path.basename(invoice_path) +# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) +# result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) +# assert expected_result in result diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index ab3092004..e90182dde 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -1,63 +1,63 @@ -#!/usr/bin/env python3 -# _*_ coding: utf-8 _*_ - -""" -@Time : 2023/9/21 23:11:27 -@Author : Stitch-z -@File : test_invoice_ocr_assistant.py -""" - -import json -from pathlib import Path - -import pandas as pd -import pytest - -from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath -from metagpt.schema import Message - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("query", "invoice_path", "invoice_table_path", "expected_result"), - [ - ( - "Invoicing date", - Path("../../data/invoices/invoice-1.pdf"), - Path("../../../data/invoice_table/invoice-1.xlsx"), - [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], - ), - ( - "Invoicing date", - Path("../../data/invoices/invoice-2.png"), - Path("../../../data/invoice_table/invoice-2.xlsx"), - [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], - ), - ( - "Invoicing date", - Path("../../data/invoices/invoice-3.jpg"), - Path("../../../data/invoice_table/invoice-3.xlsx"), - [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], - ), - ( - "Invoicing date", - Path("../../data/invoices/invoice-4.zip"), - Path("../../../data/invoice_table/invoice-4.xlsx"), - [ - {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, - {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, - {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, - ], - ), - ], -) -async def test_invoice_ocr_assistant( - query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] -): - invoice_path = Path.cwd() / invoice_path - role = InvoiceOCRAssistant() - await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) - invoice_table_path = Path.cwd() / invoice_table_path - df = pd.read_excel(invoice_table_path) - dict_result = df.to_dict(orient="records") - assert json.dumps(dict_result) == json.dumps(expected_result) +# #!/usr/bin/env python3 +# # _*_ coding: utf-8 _*_ +# +# """ +# @Time : 2023/9/21 23:11:27 +# @Author : Stitch-z +# @File : test_invoice_ocr_assistant.py +# """ +# +# import json +# from pathlib import Path +# +# import pandas as pd +# import pytest +# +# from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath +# from metagpt.schema import Message +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# ("query", "invoice_path", "invoice_table_path", "expected_result"), +# [ +# ( +# "Invoicing date", +# Path("../../data/invoices/invoice-1.pdf"), +# Path("../../../data/invoice_table/invoice-1.xlsx"), +# [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], +# ), +# ( +# "Invoicing date", +# Path("../../data/invoices/invoice-2.png"), +# Path("../../../data/invoice_table/invoice-2.xlsx"), +# [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], +# ), +# ( +# "Invoicing date", +# Path("../../data/invoices/invoice-3.jpg"), +# Path("../../../data/invoice_table/invoice-3.xlsx"), +# [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], +# ), +# ( +# "Invoicing date", +# Path("../../data/invoices/invoice-4.zip"), +# Path("../../../data/invoice_table/invoice-4.xlsx"), +# [ +# {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, +# {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, +# {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, +# ], +# ), +# ], +# ) +# async def test_invoice_ocr_assistant( +# query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] +# ): +# invoice_path = Path.cwd() / invoice_path +# role = InvoiceOCRAssistant() +# await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) +# invoice_table_path = Path.cwd() / invoice_table_path +# df = pd.read_excel(invoice_table_path) +# dict_result = df.to_dict(orient="records") +# assert json.dumps(dict_result) == json.dumps(expected_result) diff --git a/tests/metagpt/tools/test_code_interpreter.py b/tests/metagpt/tools/test_code_interpreter.py index 03d4ce8df..792f7b05b 100644 --- a/tests/metagpt/tools/test_code_interpreter.py +++ b/tests/metagpt/tools/test_code_interpreter.py @@ -1,43 +1,43 @@ -from pathlib import Path - -import pandas as pd -import pytest - -from metagpt.actions import Action -from metagpt.logs import logger -from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator - -logger.add("./tests/data/test_ci.log") -stock = "./tests/data/baba_stock.csv" - - -# TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。 -class CreateStockIndicators(Action): - @OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py") - async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame: - """对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包; - 指标生成对应的三列: SMA, BB_upper, BB_lower - """ - ... - - -@pytest.mark.asyncio -async def test_actions(): - # 计算指标 - indicators = ["Simple Moving Average", "BollingerBands"] - stocker = CreateStockIndicators() - df, msg = await stocker.run(stock, indicators=indicators) - assert isinstance(df, pd.DataFrame) - assert "Close" in df.columns - assert "Date" in df.columns - # 将df保存为文件,将文件路径传入到下一个action - df_path = "./tests/data/stock_indicators.csv" - df.to_csv(df_path) - assert Path(df_path).is_file() - # 可视化指标结果 - figure_path = "./tests/data/figure_ci.png" - ci_ploter = OpenCodeInterpreter() - ci_ploter.chat( - f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。" - ) - assert Path(figure_path).is_file() +# from pathlib import Path +# +# import pandas as pd +# import pytest +# +# from metagpt.actions import Action +# from metagpt.logs import logger +# from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator +# +# logger.add("./tests/data/test_ci.log") +# stock = "./tests/data/baba_stock.csv" +# +# +# # TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。 +# class CreateStockIndicators(Action): +# @OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py") +# async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame: +# """对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包; +# 指标生成对应的三列: SMA, BB_upper, BB_lower +# """ +# ... +# +# +# @pytest.mark.asyncio +# async def test_actions(): +# # 计算指标 +# indicators = ["Simple Moving Average", "BollingerBands"] +# stocker = CreateStockIndicators() +# df, msg = await stocker.run(stock, indicators=indicators) +# assert isinstance(df, pd.DataFrame) +# assert "Close" in df.columns +# assert "Date" in df.columns +# # 将df保存为文件,将文件路径传入到下一个action +# df_path = "./tests/data/stock_indicators.csv" +# df.to_csv(df_path) +# assert Path(df_path).is_file() +# # 可视化指标结果 +# figure_path = "./tests/data/figure_ci.png" +# ci_ploter = OpenCodeInterpreter() +# ci_ploter.chat( +# f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。" +# ) +# assert Path(figure_path).is_file()