Update: improve the unit testing of tutorial assistants and OCR assistants.

This commit is contained in:
Stitch-z 2023-12-26 14:33:17 +08:00
parent 6432ed6e60
commit dc77a0d99b
2 changed files with 17 additions and 8 deletions

View file

@ -34,7 +34,10 @@ async def test_invoice_ocr(invoice_path: str):
@pytest.mark.parametrize(
("invoice_path", "expected_result"),
[
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
(
"../../data/invoices/invoice-1.pdf",
[{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]
),
],
)
async def test_generate_table(invoice_path: str, expected_result: list[dict]):

View file

@ -7,7 +7,6 @@
@File : test_invoice_ocr_assistant.py
"""
import json
from pathlib import Path
import pandas as pd
@ -25,29 +24,36 @@ from metagpt.schema import Message
"Invoicing date",
Path("../../data/invoices/invoice-1.pdf"),
Path("../../../data/invoice_table/invoice-1.xlsx"),
[{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-2.png"),
Path("../../../data/invoice_table/invoice-2.xlsx"),
[{"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
{"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-3.jpg"),
Path("../../../data/invoice_table/invoice-3.xlsx"),
[{"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
{"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
)
],
)
async def test_invoice_ocr_assistant(
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict
):
invoice_path = Path.cwd() / invoice_path
role = InvoiceOCRAssistant()
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
invoice_table_path = Path.cwd() / invoice_table_path
df = pd.read_excel(invoice_table_path)
dict_result = df.to_dict(orient="records")
assert json.dumps(dict_result) == json.dumps(expected_result)
resp = df.to_dict(orient="records")
assert isinstance(resp, list)
assert len(resp) == 1
resp = resp[0]
assert expected_result["收款人"] == resp["收款人"]
assert expected_result["城市"] in resp["城市"]
assert int(expected_result["总费用/元"]) == int(resp["总费用/元"])
assert expected_result["开票日期"] == resp["开票日期"]