diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index ddadda7e6..7f16aa9a4 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -1,58 +1,58 @@ -# #!/usr/bin/env python3 -# # _*_ coding: utf-8 _*_ -# -# """ -# @Time : 2023/10/09 18:40:34 -# @Author : Stitch-z -# @File : test_invoice_ocr.py -# """ -# -# import os -# from pathlib import Path -# -# import pytest -# -# from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion -# -# -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# "invoice_path", -# [ -# "../../data/invoices/invoice-3.jpg", -# "../../data/invoices/invoice-4.zip", -# ], -# ) -# async def test_invoice_ocr(invoice_path: str): -# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) -# filename = os.path.basename(invoice_path) -# resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) -# assert isinstance(resp, list) -# -# -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# ("invoice_path", "expected_result"), -# [ -# ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), -# ], -# ) -# async def test_generate_table(invoice_path: str, expected_result: list[dict]): -# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) -# filename = os.path.basename(invoice_path) -# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) -# table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) -# assert table_data == expected_result -# -# -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# ("invoice_path", "query", "expected_result"), -# [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], -# ) -# async def test_reply_question(invoice_path: str, query: dict, expected_result: str): -# invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) -# filename = os.path.basename(invoice_path) -# ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) -# result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) -# assert expected_result in result +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/10/09 18:40:34 +@Author : Stitch-z +@File : test_invoice_ocr.py +""" + +import os +from pathlib import Path + +import pytest + +from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "invoice_path", + [ + "../../data/invoices/invoice-3.jpg", + "../../data/invoices/invoice-4.zip", + ], +) +async def test_invoice_ocr(invoice_path: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + assert isinstance(resp, list) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "expected_result"), + [ + ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), + ], +) +async def test_generate_table(invoice_path: str, expected_result: list[dict]): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) + assert table_data == expected_result + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("invoice_path", "query", "expected_result"), + [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], +) +async def test_reply_question(invoice_path: str, query: dict, expected_result: str): + invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) + filename = os.path.basename(invoice_path) + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) + result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) + assert expected_result in result diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index e90182dde..ab3092004 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -1,63 +1,63 @@ -# #!/usr/bin/env python3 -# # _*_ coding: utf-8 _*_ -# -# """ -# @Time : 2023/9/21 23:11:27 -# @Author : Stitch-z -# @File : test_invoice_ocr_assistant.py -# """ -# -# import json -# from pathlib import Path -# -# import pandas as pd -# import pytest -# -# from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath -# from metagpt.schema import Message -# -# -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# ("query", "invoice_path", "invoice_table_path", "expected_result"), -# [ -# ( -# "Invoicing date", -# Path("../../data/invoices/invoice-1.pdf"), -# Path("../../../data/invoice_table/invoice-1.xlsx"), -# [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], -# ), -# ( -# "Invoicing date", -# Path("../../data/invoices/invoice-2.png"), -# Path("../../../data/invoice_table/invoice-2.xlsx"), -# [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], -# ), -# ( -# "Invoicing date", -# Path("../../data/invoices/invoice-3.jpg"), -# Path("../../../data/invoice_table/invoice-3.xlsx"), -# [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], -# ), -# ( -# "Invoicing date", -# Path("../../data/invoices/invoice-4.zip"), -# Path("../../../data/invoice_table/invoice-4.xlsx"), -# [ -# {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, -# {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, -# {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, -# ], -# ), -# ], -# ) -# async def test_invoice_ocr_assistant( -# query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] -# ): -# invoice_path = Path.cwd() / invoice_path -# role = InvoiceOCRAssistant() -# await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) -# invoice_table_path = Path.cwd() / invoice_table_path -# df = pd.read_excel(invoice_table_path) -# dict_result = df.to_dict(orient="records") -# assert json.dumps(dict_result) == json.dumps(expected_result) +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 23:11:27 +@Author : Stitch-z +@File : test_invoice_ocr_assistant.py +""" + +import json +from pathlib import Path + +import pandas as pd +import pytest + +from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath +from metagpt.schema import Message + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "invoice_path", "invoice_table_path", "expected_result"), + [ + ( + "Invoicing date", + Path("../../data/invoices/invoice-1.pdf"), + Path("../../../data/invoice_table/invoice-1.xlsx"), + [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-2.png"), + Path("../../../data/invoice_table/invoice-2.xlsx"), + [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-3.jpg"), + Path("../../../data/invoice_table/invoice-3.xlsx"), + [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], + ), + ( + "Invoicing date", + Path("../../data/invoices/invoice-4.zip"), + Path("../../../data/invoice_table/invoice-4.xlsx"), + [ + {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, + {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, + {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, + ], + ), + ], +) +async def test_invoice_ocr_assistant( + query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] +): + invoice_path = Path.cwd() / invoice_path + role = InvoiceOCRAssistant() + await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) + invoice_table_path = Path.cwd() / invoice_table_path + df = pd.read_excel(invoice_table_path) + dict_result = df.to_dict(orient="records") + assert json.dumps(dict_result) == json.dumps(expected_result)