mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
comment interpreter & ocr files for test
This commit is contained in:
parent
9531dbf3ff
commit
029884590f
3 changed files with 164 additions and 164 deletions
|
|
@ -1,58 +1,58 @@
|
|||
#!/usr/bin/env python3
|
||||
# _*_ coding: utf-8 _*_
|
||||
|
||||
"""
|
||||
@Time : 2023/10/09 18:40:34
|
||||
@Author : Stitch-z
|
||||
@File : test_invoice_ocr.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "invoice_path",
    [
        "../../data/invoices/invoice-3.jpg",
        "../../data/invoices/invoice-4.zip",
    ],
)
async def test_invoice_ocr(invoice_path: str):
    """Run ``InvoiceOCR`` on one invoice file and check that it returns a list.

    Args:
        invoice_path: Invoice location relative to the current working
            directory; it is made absolute before being passed to the action.
    """
    # The relative path is anchored at the process working directory, so the
    # file that gets resolved depends on where pytest is launched from.
    absolute_path = Path(os.path.abspath(os.path.join(os.getcwd(), invoice_path)))
    ocr_output = await InvoiceOCR().run(file_path=absolute_path, filename=absolute_path.name)
    assert isinstance(ocr_output, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("invoice_path", "expected_result"),
    [
        (
            "../../data/invoices/invoice-1.pdf",
            [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}],
        ),
    ],
)
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
    """Feed the OCR output of an invoice to ``GenerateTable`` and compare the rows.

    Args:
        invoice_path: Invoice location relative to the current working directory.
        expected_result: Rows the table action must return, compared with ``==``.
    """
    absolute_path = Path(os.path.abspath(os.path.join(os.getcwd(), invoice_path)))
    name = absolute_path.name
    # Stage 1: OCR the invoice; stage 2: turn the OCR output into table rows.
    ocr_output = await InvoiceOCR().run(file_path=absolute_path, filename=name)
    rows = await GenerateTable().run(ocr_results=ocr_output, filename=name)
    assert rows == expected_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("invoice_path", "query", "expected_result"),
    [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
)
async def test_reply_question(invoice_path: str, query: str, expected_result: str):
    """Ask ``ReplyQuestion`` about an OCR'd invoice and check the answer text.

    Args:
        invoice_path: Invoice location relative to the current working directory.
        query: Question passed to ``ReplyQuestion``. The parametrized value is a
            plain string, so it is annotated ``str`` (it was previously ``dict``).
        expected_result: Text that must occur in the reply (checked with ``in``).
    """
    # Resolved against the process working directory, so the outcome depends on
    # where pytest is launched from.
    invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
    filename = os.path.basename(invoice_path)
    ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
    result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
    assert expected_result in result
|
||||
# #!/usr/bin/env python3
|
||||
# # _*_ coding: utf-8 _*_
|
||||
#
|
||||
# """
|
||||
# @Time : 2023/10/09 18:40:34
|
||||
# @Author : Stitch-z
|
||||
# @File : test_invoice_ocr.py
|
||||
# """
|
||||
#
|
||||
# import os
|
||||
# from pathlib import Path
|
||||
#
|
||||
# import pytest
|
||||
#
|
||||
# from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# "invoice_path",
|
||||
# [
|
||||
# "../../data/invoices/invoice-3.jpg",
|
||||
# "../../data/invoices/invoice-4.zip",
|
||||
# ],
|
||||
# )
|
||||
# async def test_invoice_ocr(invoice_path: str):
|
||||
# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
# filename = os.path.basename(invoice_path)
|
||||
# resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
# assert isinstance(resp, list)
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# ("invoice_path", "expected_result"),
|
||||
# [
|
||||
# ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
|
||||
# ],
|
||||
# )
|
||||
# async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
||||
# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
# filename = os.path.basename(invoice_path)
|
||||
# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
# table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
|
||||
# assert table_data == expected_result
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# ("invoice_path", "query", "expected_result"),
|
||||
# [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
|
||||
# )
|
||||
# async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
|
||||
# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
# filename = os.path.basename(invoice_path)
|
||||
# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
# result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
|
||||
# assert expected_result in result
|
||||
|
|
|
|||
|
|
@ -1,63 +1,63 @@
|
|||
#!/usr/bin/env python3
|
||||
# _*_ coding: utf-8 _*_
|
||||
|
||||
"""
|
||||
@Time : 2023/9/21 23:11:27
|
||||
@Author : Stitch-z
|
||||
@File : test_invoice_ocr_assistant.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("query", "invoice_path", "invoice_table_path", "expected_result"),
    [
        (
            "Invoicing date",
            Path("../../data/invoices/invoice-1.pdf"),
            Path("../../../data/invoice_table/invoice-1.xlsx"),
            [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
        ),
        (
            "Invoicing date",
            Path("../../data/invoices/invoice-2.png"),
            Path("../../../data/invoice_table/invoice-2.xlsx"),
            [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
        ),
        (
            "Invoicing date",
            Path("../../data/invoices/invoice-3.jpg"),
            Path("../../../data/invoice_table/invoice-3.xlsx"),
            [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
        ),
        (
            "Invoicing date",
            Path("../../data/invoices/invoice-4.zip"),
            Path("../../../data/invoice_table/invoice-4.xlsx"),
            [
                {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
                {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
                {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
            ],
        ),
    ],
)
async def test_invoice_ocr_assistant(
    query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
):
    """Run the ``InvoiceOCRAssistant`` role on an invoice and check the spreadsheet.

    Args:
        query: Message content sent to the role.
        invoice_path: Invoice file, relative to the current working directory.
        invoice_table_path: Spreadsheet read back after the run, relative to the
            current working directory (presumably written by the role -- confirm).
        expected_result: Records the spreadsheet must contain.
    """
    source_file = Path.cwd() / invoice_path
    assistant = InvoiceOCRAssistant()
    request = Message(content=query, instruct_content=InvoicePath(file_path=source_file))
    await assistant.run(request)

    # The table location is resolved only after the role has finished running.
    table_file = Path.cwd() / invoice_table_path
    records = pd.read_excel(table_file).to_dict(orient="records")
    # Compare the JSON serialisations rather than the objects themselves, which
    # also distinguishes e.g. 412 from 412.0.
    assert json.dumps(records) == json.dumps(expected_result)
|
||||
# #!/usr/bin/env python3
|
||||
# # _*_ coding: utf-8 _*_
|
||||
#
|
||||
# """
|
||||
# @Time : 2023/9/21 23:11:27
|
||||
# @Author : Stitch-z
|
||||
# @File : test_invoice_ocr_assistant.py
|
||||
# """
|
||||
#
|
||||
# import json
|
||||
# from pathlib import Path
|
||||
#
|
||||
# import pandas as pd
|
||||
# import pytest
|
||||
#
|
||||
# from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
|
||||
# from metagpt.schema import Message
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# ("query", "invoice_path", "invoice_table_path", "expected_result"),
|
||||
# [
|
||||
# (
|
||||
# "Invoicing date",
|
||||
# Path("../../data/invoices/invoice-1.pdf"),
|
||||
# Path("../../../data/invoice_table/invoice-1.xlsx"),
|
||||
# [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
|
||||
# ),
|
||||
# (
|
||||
# "Invoicing date",
|
||||
# Path("../../data/invoices/invoice-2.png"),
|
||||
# Path("../../../data/invoice_table/invoice-2.xlsx"),
|
||||
# [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
|
||||
# ),
|
||||
# (
|
||||
# "Invoicing date",
|
||||
# Path("../../data/invoices/invoice-3.jpg"),
|
||||
# Path("../../../data/invoice_table/invoice-3.xlsx"),
|
||||
# [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
|
||||
# ),
|
||||
# (
|
||||
# "Invoicing date",
|
||||
# Path("../../data/invoices/invoice-4.zip"),
|
||||
# Path("../../../data/invoice_table/invoice-4.xlsx"),
|
||||
# [
|
||||
# {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
|
||||
# {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
|
||||
# {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
|
||||
# ],
|
||||
# ),
|
||||
# ],
|
||||
# )
|
||||
# async def test_invoice_ocr_assistant(
|
||||
# query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
|
||||
# ):
|
||||
# invoice_path = Path.cwd() / invoice_path
|
||||
# role = InvoiceOCRAssistant()
|
||||
# await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
|
||||
# invoice_table_path = Path.cwd() / invoice_table_path
|
||||
# df = pd.read_excel(invoice_table_path)
|
||||
# dict_result = df.to_dict(orient="records")
|
||||
# assert json.dumps(dict_result) == json.dumps(expected_result)
|
||||
|
|
|
|||
|
|
@ -1,43 +1,43 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
|
||||
|
||||
# Register ./tests/data/test_ci.log with the project logger at import time
# (presumably an extra log sink -- confirm against metagpt.logs).
logger.add("./tests/data/test_ci.log")
# CSV passed to CreateStockIndicators.run in test_actions; the path is relative
# to the process working directory.
stock = "./tests/data/baba_stock.csv"
|
||||
|
||||
|
||||
# TODO: we need a tabular data format that supports schema management, annotating each field's type and meaning.
|
||||
class CreateStockIndicators(Action):
    # Action whose ``run`` body is only ``...``. The method is wrapped by
    # ``OpenInterpreterDecorator``, which presumably supplies the real behaviour
    # from the signature and docstring -- TODO confirm against
    # metagpt.tools.code_interpreter.
    @OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
    async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame:
        """对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
        指标生成对应的三列: SMA, BB_upper, BB_lower
        """
        # English gloss of the docstring above, which is left untranslated
        # because the decorator may consume it at runtime (confirm): "For the
        # stock data in stock_path, use pandas and ta to compute the technical
        # indicators listed in indicators and return the stock data with those
        # indicators; do not drop null values; do not install any packages; the
        # indicators produce three columns: SMA, BB_upper, BB_lower."
        # NOTE(review): ``indicators`` has a mutable list default; this body
        # never mutates it.
        # NOTE(review): test_actions unpacks the awaited result into two values,
        # so the decorated call apparently returns a pair rather than the
        # annotated DataFrame -- confirm.
        ...
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_actions():
    """Compute stock indicators, save them as CSV, then ask the interpreter for a plot.

    Checks that the indicator step yields a DataFrame with ``Close`` and ``Date``
    columns, that the CSV exists after saving, and that the figure file exists
    after the chat call.
    """
    # Compute the indicators.
    wanted = ["Simple Moving Average", "BollingerBands"]
    df, _msg = await CreateStockIndicators().run(stock, indicators=wanted)
    assert isinstance(df, pd.DataFrame)
    for column in ("Close", "Date"):
        assert column in df.columns

    # Save df to a file; its path is handed to the next step via the prompt.
    df_path = "./tests/data/stock_indicators.csv"
    df.to_csv(df_path)
    assert Path(df_path).is_file()

    # Visualise the indicator results.
    figure_path = "./tests/data/figure_ci.png"
    interpreter = OpenCodeInterpreter()
    prompt = f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。"
    interpreter.chat(prompt)
    assert Path(figure_path).is_file()
|
||||
# from pathlib import Path
|
||||
#
|
||||
# import pandas as pd
|
||||
# import pytest
|
||||
#
|
||||
# from metagpt.actions import Action
|
||||
# from metagpt.logs import logger
|
||||
# from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
|
||||
#
|
||||
# logger.add("./tests/data/test_ci.log")
|
||||
# stock = "./tests/data/baba_stock.csv"
|
||||
#
|
||||
#
|
||||
# # TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。
|
||||
# class CreateStockIndicators(Action):
|
||||
# @OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
|
||||
# async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame:
|
||||
# """对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
|
||||
# 指标生成对应的三列: SMA, BB_upper, BB_lower
|
||||
# """
|
||||
# ...
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_actions():
|
||||
# # 计算指标
|
||||
# indicators = ["Simple Moving Average", "BollingerBands"]
|
||||
# stocker = CreateStockIndicators()
|
||||
# df, msg = await stocker.run(stock, indicators=indicators)
|
||||
# assert isinstance(df, pd.DataFrame)
|
||||
# assert "Close" in df.columns
|
||||
# assert "Date" in df.columns
|
||||
# # 将df保存为文件,将文件路径传入到下一个action
|
||||
# df_path = "./tests/data/stock_indicators.csv"
|
||||
# df.to_csv(df_path)
|
||||
# assert Path(df_path).is_file()
|
||||
# # 可视化指标结果
|
||||
# figure_path = "./tests/data/figure_ci.png"
|
||||
# ci_ploter = OpenCodeInterpreter()
|
||||
# ci_ploter.chat(
|
||||
# f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。"
|
||||
# )
|
||||
# assert Path(figure_path).is_file()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue