Update: 发票ocr助手单测数据路径改为从const获取

This commit is contained in:
Stitch-z 2023-12-28 22:32:40 +08:00
parent 19bc5f57b2
commit ac6ec8e152
4 changed files with 36 additions and 31 deletions

View file

@ -6,27 +6,26 @@
@Author : Stitch-z
@File : test_invoice_ocr.py
"""
import json
import os
from pathlib import Path
import pytest
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
from metagpt.const import TEST_DATA_PATH
@pytest.mark.asyncio
@pytest.mark.parametrize(
"invoice_path",
[
"../../data/invoices/invoice-3.jpg",
# "../../data/invoices/invoice-4.zip",
Path("invoices/invoice-3.jpg"),
Path("invoices/invoice-4.zip"),
],
)
async def test_invoice_ocr(invoice_path: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
async def test_invoice_ocr(invoice_path: Path):
invoice_path = TEST_DATA_PATH / invoice_path
resp = await InvoiceOCR().run(file_path=Path(invoice_path))
assert isinstance(resp, list)
@ -34,25 +33,32 @@ async def test_invoice_ocr(invoice_path: str):
@pytest.mark.parametrize(
("invoice_path", "expected_result"),
[
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
(
Path("invoices/invoice-1.pdf"),
{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}
),
],
)
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
async def test_generate_table(invoice_path: Path, expected_result: dict):
invoice_path = TEST_DATA_PATH / invoice_path
filename = invoice_path.name
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
assert json.dumps(table_data) == json.dumps(expected_result)
assert isinstance(table_data, list)
table_data = table_data[0]
assert expected_result["收款人"] == table_data["收款人"]
assert expected_result["城市"] in table_data["城市"]
assert float(expected_result["总费用/元"]) == float(table_data["总费用/元"])
assert expected_result["开票日期"] == table_data["开票日期"]
@pytest.mark.asyncio
@pytest.mark.parametrize(
("invoice_path", "query", "expected_result"),
[("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
[(Path("invoices/invoice-1.pdf"), "Invoicing date", "2023年02月03日")],
)
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
async def test_reply_question(invoice_path: Path, query: dict, expected_result: str):
invoice_path = TEST_DATA_PATH / invoice_path
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
assert expected_result in result

View file

@ -12,6 +12,7 @@ from pathlib import Path
import pandas as pd
import pytest
from metagpt.const import TEST_DATA_PATH, DATA_PATH
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
from metagpt.schema import Message
@ -22,29 +23,29 @@ from metagpt.schema import Message
[
(
"Invoicing date",
Path("../../data/invoices/invoice-1.pdf"),
Path("../../../data/invoice_table/invoice-1.xlsx"),
Path("invoices/invoice-1.pdf"),
Path("invoice_table/invoice-1.xlsx"),
{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-2.png"),
Path("../../../data/invoice_table/invoice-2.xlsx"),
Path("invoices/invoice-2.png"),
Path("invoice_table/invoice-2.xlsx"),
{"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-3.jpg"),
Path("../../../data/invoice_table/invoice-3.xlsx"),
Path("invoices/invoice-3.jpg"),
Path("invoice_table/invoice-3.xlsx"),
{"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
),
],
)
async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict):
invoice_path = Path.cwd() / invoice_path
invoice_path = TEST_DATA_PATH / invoice_path
role = InvoiceOCRAssistant()
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
invoice_table_path = Path.cwd() / invoice_table_path
invoice_table_path = DATA_PATH / invoice_table_path
df = pd.read_excel(invoice_table_path)
resp = df.to_dict(orient="records")
assert isinstance(resp, list)
@ -52,5 +53,5 @@ async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_tab
resp = resp[0]
assert expected_result["收款人"] == resp["收款人"]
assert expected_result["城市"] in resp["城市"]
assert int(expected_result["总费用/元"]) == int(resp["总费用/元"])
assert float(expected_result["总费用/元"]) == float(resp["总费用/元"])
assert expected_result["开票日期"] == resp["开票日期"]

View file

@ -5,7 +5,6 @@
@Author : Stitch-z
@File : test_tutorial_assistant.py
"""
import shutil
import aiofiles
import pytest
@ -17,8 +16,6 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
@pytest.mark.asyncio
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")])
async def test_tutorial_assistant(language: str, topic: str):
shutil.rmtree(path=TUTORIAL_PATH, ignore_errors=True)
role = TutorialAssistant(language=language)
msg = await role.run(topic)
assert TUTORIAL_PATH.exists()