diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index 12b1b4b30..b3b93cf9f 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -34,7 +34,10 @@ async def test_invoice_ocr(invoice_path: str): @pytest.mark.parametrize( ("invoice_path", "expected_result"), [ - ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), + ( + "../../data/invoices/invoice-1.pdf", + [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}] + ), ], ) async def test_generate_table(invoice_path: str, expected_result: list[dict]): diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index 38436fa60..48abb9eb8 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -7,7 +7,6 @@ @File : test_invoice_ocr_assistant.py """ -import json from pathlib import Path import pandas as pd @@ -25,29 +24,36 @@ from metagpt.schema import Message "Invoicing date", Path("../../data/invoices/invoice-1.pdf"), Path("../../../data/invoice_table/invoice-1.xlsx"), - [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], + {"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, ), ( "Invoicing date", Path("../../data/invoices/invoice-2.png"), Path("../../../data/invoice_table/invoice-2.xlsx"), - [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], + {"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, ), ( "Invoicing date", Path("../../data/invoices/invoice-3.jpg"), Path("../../../data/invoice_table/invoice-3.xlsx"), - [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], + {"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, ) ], ) async def test_invoice_ocr_assistant( - query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] + query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict ): invoice_path = Path.cwd() / invoice_path role = InvoiceOCRAssistant() await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) invoice_table_path = Path.cwd() / invoice_table_path df = pd.read_excel(invoice_table_path) - dict_result = df.to_dict(orient="records") - assert json.dumps(dict_result) == json.dumps(expected_result) + resp = df.to_dict(orient="records") + assert isinstance(resp, list) + assert len(resp) == 1 + resp = resp[0] + assert expected_result["收款人"] == resp["收款人"] + assert expected_result["城市"] in resp["城市"] + assert int(expected_result["总费用/元"]) == int(resp["总费用/元"]) + assert expected_result["开票日期"] == resp["开票日期"] +