diff --git a/metagpt/const.py b/metagpt/const.py index 5e149ed72..a57be641b 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -53,6 +53,7 @@ DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace" EXAMPLE_PATH = METAGPT_ROOT / "examples" DATA_PATH = METAGPT_ROOT / "data" +TEST_DATA_PATH = METAGPT_ROOT / "tests/data" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index d569fda21..3dc233686 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -6,27 +6,26 @@ @Author : Stitch-z @File : test_invoice_ocr.py """ -import json -import os + from pathlib import Path import pytest from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion +from metagpt.const import TEST_DATA_PATH @pytest.mark.asyncio @pytest.mark.parametrize( "invoice_path", [ - "../../data/invoices/invoice-3.jpg", - # "../../data/invoices/invoice-4.zip", + Path("invoices/invoice-3.jpg"), + Path("invoices/invoice-4.zip"), ], ) -async def test_invoice_ocr(invoice_path: str): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) - filename = os.path.basename(invoice_path) - resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) +async def test_invoice_ocr(invoice_path: Path): + invoice_path = TEST_DATA_PATH / invoice_path + resp = await InvoiceOCR().run(file_path=Path(invoice_path)) assert isinstance(resp, list) @@ -34,25 +33,32 @@ async def test_invoice_ocr(invoice_path: str): @pytest.mark.parametrize( ("invoice_path", "expected_result"), [ - ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), + ( + Path("invoices/invoice-1.pdf"), + {"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"} + ), ], ) -async def test_generate_table(invoice_path: str, expected_result: list[dict]): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) - filename = os.path.basename(invoice_path) - ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) +async def test_generate_table(invoice_path: Path, expected_result: dict): + invoice_path = TEST_DATA_PATH / invoice_path + filename = invoice_path.name + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path)) table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) - assert json.dumps(table_data) == json.dumps(expected_result) + assert isinstance(table_data, list) + table_data = table_data[0] + assert expected_result["收款人"] == table_data["收款人"] + assert expected_result["城市"] in table_data["城市"] + assert float(expected_result["总费用/元"]) == float(table_data["总费用/元"]) + assert expected_result["开票日期"] == table_data["开票日期"] @pytest.mark.asyncio @pytest.mark.parametrize( ("invoice_path", "query", "expected_result"), - [("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")], + [(Path("invoices/invoice-1.pdf"), "Invoicing date", "2023年02月03日")], ) -async def test_reply_question(invoice_path: str, query: dict, expected_result: str): - invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path)) - filename = os.path.basename(invoice_path) - ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) +async def test_reply_question(invoice_path: Path, query: dict, expected_result: str): + invoice_path = TEST_DATA_PATH / invoice_path + ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path)) result = await ReplyQuestion().run(query=query, ocr_result=ocr_result) assert expected_result in result diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index 500d93a77..11b993dc0 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -12,6 +12,7 @@ from pathlib import Path import pandas as pd import pytest +from metagpt.const import TEST_DATA_PATH, DATA_PATH from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath from metagpt.schema import Message @@ -22,29 +23,29 @@ from metagpt.schema import Message [ ( "Invoicing date", - Path("../../data/invoices/invoice-1.pdf"), - Path("../../../data/invoice_table/invoice-1.xlsx"), + Path("invoices/invoice-1.pdf"), + Path("invoice_table/invoice-1.xlsx"), {"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, ), ( "Invoicing date", - Path("../../data/invoices/invoice-2.png"), - Path("../../../data/invoice_table/invoice-2.xlsx"), + Path("invoices/invoice-2.png"), + Path("invoice_table/invoice-2.xlsx"), {"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, ), ( "Invoicing date", - Path("../../data/invoices/invoice-3.jpg"), - Path("../../../data/invoice_table/invoice-3.xlsx"), + Path("invoices/invoice-3.jpg"), + Path("invoice_table/invoice-3.xlsx"), {"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, ), ], ) async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict): - invoice_path = Path.cwd() / invoice_path + invoice_path = TEST_DATA_PATH / invoice_path role = InvoiceOCRAssistant() await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) - invoice_table_path = Path.cwd() / invoice_table_path + invoice_table_path = DATA_PATH / invoice_table_path df = pd.read_excel(invoice_table_path) resp = df.to_dict(orient="records") assert isinstance(resp, list) @@ -52,5 +53,5 @@ async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_tab resp = resp[0] assert expected_result["收款人"] == resp["收款人"] assert expected_result["城市"] in resp["城市"] - assert int(expected_result["总费用/元"]) == int(resp["总费用/元"]) + assert float(expected_result["总费用/元"]) == float(resp["总费用/元"]) assert expected_result["开票日期"] == resp["开票日期"] diff --git a/tests/metagpt/roles/test_tutorial_assistant.py b/tests/metagpt/roles/test_tutorial_assistant.py index ca54aaff5..0e6c1efb9 100644 --- a/tests/metagpt/roles/test_tutorial_assistant.py +++ b/tests/metagpt/roles/test_tutorial_assistant.py @@ -5,7 +5,6 @@ @Author : Stitch-z @File : test_tutorial_assistant.py """ -import shutil import aiofiles import pytest @@ -17,8 +16,6 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant @pytest.mark.asyncio @pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")]) async def test_tutorial_assistant(language: str, topic: str): - shutil.rmtree(path=TUTORIAL_PATH, ignore_errors=True) - role = TutorialAssistant(language=language) msg = await role.run(topic) assert TUTORIAL_PATH.exists()